src/HOL/BNF/Tools/bnf_lfp_compat.ML
author blanchet
Tue Oct 01 14:05:25 2013 +0200 (2013-10-01)
changeset 54006 9fe1bd54d437
parent 54003 c4343c31f86d
child 54009 f138452e8265
permissions -rw-r--r--
renamed theory file
     1 (*  Title:      HOL/BNF/Tools/bnf_lfp_compat.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2013
     4 
     5 Compatibility layer with the old datatype package.
     6 *)
     7 
     8 signature BNF_LFP_COMPAT =
     9 sig
    10   val datatype_new_compat_cmd : string list -> local_theory -> local_theory
    11 end;
    12 
    13 structure BNF_LFP_Compat : BNF_LFP_COMPAT =
    14 struct
    15 
    16 open Ctr_Sugar
    17 open BNF_Util
    18 open BNF_FP_Util
    19 open BNF_FP_Def_Sugar
    20 open BNF_FP_N2M_Sugar
    21 
    22 fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a
    23   | dtyp_of_typ recTs (T as Type (s, Ts)) =
    24     (case find_index (curry (op =) T) recTs of
    25       ~1 => Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Ts)
    26     | kk => Datatype_Aux.DtRec kk);
    27 
    28 val compatN = "compat_";
    29 
    30 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
    31 fun datatype_new_compat_cmd raw_fpT_names lthy =
    32   let
    33     val thy = Proof_Context.theory_of lthy;
    34 
    35     fun not_datatype s = error (quote s ^ " is not a new-style datatype");
    36     fun not_mutually_recursive ss =
    37       error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");
    38 
    39     val (fpT_names as fpT_name1 :: _) =
    40       map (fst o dest_Type o Proof_Context.read_type_name_proper lthy false) raw_fpT_names;
    41 
    42     val Ss = Sign.arity_sorts thy fpT_name1 HOLogic.typeS;
    43 
    44     val (unsorted_As, _) = lthy |> mk_TFrees (length Ss);
    45     val As = map2 resort_tfree Ss unsorted_As;
    46 
    47     fun lfp_sugar_of s =
    48       (case fp_sugar_of lthy s of
    49         SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
    50       | _ => not_datatype s);
    51 
    52     val fp_sugar0 as {fp_res = {Ts = fpTs0, ...}, ...} = lfp_sugar_of fpT_name1;
    53     val fpT_names' = map (fst o dest_Type) fpTs0;
    54 
    55     val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;
    56 
    57     val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
    58 
    59     fun add_nested_types_of (T as Type (s, _)) seen =
    60       if member (op =) seen T orelse s = @{type_name fun} then
    61         seen
    62       else
    63         (case try lfp_sugar_of s of
    64           SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
    65           let
    66             val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
    67             val substT = Term.typ_subst_TVars rho;
    68 
    69             val mutual_Ts = map substT mutual_Ts0;
    70 
    71             fun add_interesting_subtypes (U as Type (_, Us)) =
    72                 (case filter (exists_subtype_in mutual_Ts) Us of [] => I
    73                 | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
    74               | add_interesting_subtypes _ = I;
    75 
    76             val ctrs = maps #ctrs ctr_sugars;
    77             val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
    78             val subTs = fold add_interesting_subtypes ctr_Ts [];
    79           in
    80             fold add_nested_types_of subTs (seen @ mutual_Ts)
    81           end
    82         | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
    83             " not associated with new-style datatype (cf. \"datatype_new\")"));
    84 
    85     val Ts = add_nested_types_of fpT1 [];
    86     val b_names = map base_name_of_typ Ts;
    87     val compat_b_names = map (prefix compatN) b_names;
    88     val compat_bs = map Binding.name compat_b_names;
    89     val common_name = compatN ^ mk_common_name b_names;
    90     val nn_fp = length fpTs;
    91     val nn = length Ts;
    92     val get_indices = K [];
    93     val fp_sugars0 = if nn = 1 then [fp_sugar0] else map (lfp_sugar_of o fst o dest_Type) Ts;
    94     val callssss = pad_and_indexify_calls fp_sugars0 nn [];
    95     val has_nested = nn > nn_fp;
    96 
    97     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
    98       mutualize_fp_sugars false has_nested Least_FP compat_bs Ts get_indices callssss fp_sugars0
    99         lthy;
   100 
   101     val {ctr_sugars, co_inducts = [induct], co_iterss, co_iter_thmsss = iter_thmsss, ...} :: _ =
   102       fp_sugars;
   103     val inducts = conj_dests nn induct;
   104 
   105     val frozen_Ts = map Type.legacy_freeze_type Ts;
   106     val mk_dtyp = dtyp_of_typ frozen_Ts;
   107 
   108     fun mk_ctr_descr (Const (s, T)) =
   109       (s, map mk_dtyp (binder_types (Type.legacy_freeze_type T)));
   110     fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
   111       (index, (T_name, map mk_dtyp Ts, map mk_ctr_descr ctrs));
   112 
   113     val descr = map3 mk_typ_descr (0 upto nn - 1) frozen_Ts ctr_sugars;
   114     val recs = map (fst o dest_Const o co_rec_of) co_iterss;
   115     val rec_thms = flat (map co_rec_of iter_thmsss);
   116 
   117     fun mk_info ({T = Type (T_name0, _), index, ...} : fp_sugar) =
   118       let
   119         val {casex, exhaust, nchotomy, injects, distincts, case_thms, case_cong, weak_case_cong,
   120           split, split_asm, ...} = nth ctr_sugars index;
   121       in
   122         (T_name0,
   123          {index = index, descr = descr, inject = injects, distinct = distincts, induct = induct,
   124          inducts = inducts, exhaust = exhaust, nchotomy = nchotomy, rec_names = recs,
   125          rec_rewrites = rec_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
   126          case_cong = case_cong, weak_case_cong = weak_case_cong, split = split,
   127          split_asm = split_asm})
   128       end;
   129 
   130     val infos = map mk_info (take nn_fp fp_sugars);
   131 
   132     val all_notes =
   133       (case lfp_sugar_thms of
   134         NONE => []
   135       | SOME ((induct_thms, induct_thm, induct_attrs), (fold_thmss, rec_thmss, _)) =>
   136         let
   137           val common_notes =
   138             (if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])
   139             |> filter_out (null o #2)
   140             |> map (fn (thmN, thms, attrs) =>
   141               ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
   142 
   143           val notes =
   144             [(foldN, fold_thmss, []),
   145              (inductN, map single induct_thms, induct_attrs),
   146              (recN, rec_thmss, [])]
   147             |> filter_out (null o #2)
   148             |> maps (fn (thmN, thmss, attrs) =>
   149               if forall null thmss then
   150                 []
   151               else
   152                 map2 (fn b_name => fn thms =>
   153                     ((Binding.qualify true b_name (Binding.name thmN), attrs), [(thms, [])]))
   154                   compat_b_names thmss);
   155         in
   156           common_notes @ notes
   157         end);
   158 
   159     val register_interpret =
   160       Datatype_Data.register infos
   161       #> Datatype_Data.interpretation_data (Datatype_Aux.default_config, map fst infos)
   162   in
   163     lthy
   164     |> Local_Theory.raw_theory register_interpret
   165     |> Local_Theory.notes all_notes |> snd
   166   end;
   167 
   168 val _ =
   169   Outer_Syntax.local_theory @{command_spec "datatype_new_compat"}
   170     "register a new-style datatype as an old-style datatype"
   171     (Scan.repeat1 Parse.type_const >> datatype_new_compat_cmd);
   172 
   173 end;