src/HOL/BNF/Tools/bnf_lfp_compat.ML
changeset 53303 ae49b835ca01
child 53309 42a99f732a40
equal deleted inserted replaced
53302:98fdf6c34142 53303:ae49b835ca01
       
     1 (*  Title:      HOL/BNF/Tools/bnf_lfp_compat.ML
       
     2     Author:     Jasmin Blanchette, TU Muenchen
       
     3     Copyright   2013
       
     4 
       
     5 Compatibility layer with the old datatype package.
       
     6 *)
       
     7 
       
      8 signature BNF_LFP_COMPAT =
       
      9 sig
       
          (* Takes the names of new-style datatypes and a local theory, and returns the local
             theory in which those types have been registered with the old datatype package
             (through Datatype_Data.register in the implementation). *)
     10   val datatype_compat_cmd : string list -> local_theory -> local_theory
       
     11 end;
       
    12 
       
    13 structure BNF_LFP_Compat : BNF_LFP_COMPAT =
       
    14 struct
       
    15 
       
    16 open BNF_Util
       
    17 open BNF_FP_Util
       
    18 open BNF_FP_Def_Sugar
       
    19 open BNF_FP_N2M_Sugar
       
    20 
       
     (* Translates a typ into the Datatype_Aux representation.
        - A free type variable becomes DtTFree.
        - A type-constructor application that occurs itself in recTs becomes DtRec kk, where kk
          is its position in recTs (find_index yields ~1 when there is no such position).
        - Any other type-constructor application becomes DtType, with its arguments translated
          recursively against the same recTs.
        There is no clause for other forms of typ; the caller visible in this file passes types
        that went through Type.legacy_freeze_type first (presumably leaving no schematic
        variables -- confirm before calling this from elsewhere). *)
     21 fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a
       
     22   | dtyp_of_typ recTs (T as Type (s, Ts)) =
       
     23     (case find_index (curry (op =) T) recTs of
       
     24       ~1 => Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Ts)
       
     25     | kk => Datatype_Aux.DtRec kk);
       
    26 
       
     (* Prefix put in front of the base name of each type when the bindings handed to
        mutualize_fp_sugars are built in datatype_compat_cmd. *)
     27 val compatN = "compat_";
       
    28 
       
     29 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
       
        (* Implementation of the compatibility command.
           Steps, in order:
             1. read the given type names and check that they are exactly the mutually recursive
                cluster recorded for the first of them (least-fixpoint sugar only);
             2. collect the further datatypes that the cluster is nested through;
             3. run mutualize_fp_sugars over all collected types;
             4. build one old-style info record per type of the cluster and register these with
                Datatype_Data on the background theory.
           Fails with error when a name has no Least_FP sugar, when the names do not form the
           complete cluster, or when nesting goes through a type constructor without such sugar.
           Several val bindings below use refutable patterns and raise Bind when they do not
           match (e.g. an empty raw_fpT_names). *)
     30 fun datatype_compat_cmd raw_fpT_names lthy =
       
     31   let
       
     32     val thy = Proof_Context.theory_of lthy;
       
     33 
       
            (* Error helpers; both end in error and therefore never return normally. *)
     34     fun not_datatype s = error (quote s ^ " is not a new-style datatype");
       
     35     fun not_mutually_recursive ss =
       
     36       error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");
       
     37 
       
            (* Resolve the user-supplied names to type constructor names. The pattern also binds
               the first name, so an empty input list does not match. *)
     38     val (fpT_names as fpT_name1 :: _) =
       
     39       map (fst o dest_Type o Proof_Context.read_type_name_proper lthy false) raw_fpT_names;
       
     40 
       
            (* Sorts required of the type arguments of the first datatype, taken from its arity
               at HOLogic.typeS. *)
     41     val Ss = Sign.arity_sorts thy fpT_name1 HOLogic.typeS;
       
     42 
       
            (* One type variable per argument (mk_TFrees comes from the opened BNF utilities,
               presumably producing fresh TFrees), each then given the matching sort from Ss. *)
     43     val (unsorted_As, _) = lthy |> mk_TFrees (length Ss);
       
     44     val As = map2 resort_tfree Ss unsorted_As;
       
     45 
       
            (* Looks up the fp_sugar of s and accepts it only when its fp field is Least_FP;
               a missing entry or any other kind is reported through not_datatype. *)
     46     fun lfp_sugar_of s =
       
     47       (case fp_sugar_of lthy s of
       
     48         SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
       
     49       | _ => not_datatype s);
       
     50 
       
            (* fpTs0 are the types recorded in fp_res of the first datatype's sugar; their
               constructor names form the reference cluster fpT_names'. *)
     51     val fp_sugar0 as {fp_res = {Ts = fpTs0, ...}, ...} = lfp_sugar_of fpT_name1;
       
     52     val fpT_names' = map (fst o dest_Type) fpTs0;
       
     53 
       
            (* The user must have listed exactly that cluster (compared with eq_set, so the order
               given on the command line does not matter). *)
     54     val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;
       
     55 
       
            (* The cluster's types in the order recorded in fp_res, applied to the variables As. *)
     56     val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
       
     57 
       
            (* Accumulates in seen the datatypes reachable from T through nesting.
               Types already in seen and the function type constructor are skipped unchanged.
               Only Type (...) arguments are matched; there is no clause for other typ forms. *)
     58     fun add_nested_types_of (T as Type (s, _)) seen =
       
     59       if member (op =) seen T orelse s = @{type_name fun} then
       
     60         seen
       
     61       else
       
                (* try turns the error raised by lfp_sugar_of into NONE, which is reported with a
                   more specific message in the NONE branch below. *)
     62         (case try lfp_sugar_of s of
       
     63           SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
       
     64           let
       
                    (* Match the type T0 stored in the sugar against T and turn the resulting
                       table into the substitution list expected by typ_subst_TVars. *)
     65             val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
       
     66             val substT = Term.typ_subst_TVars rho;
       
     67 
       
                    (* The mutual cluster of s, instantiated like T. *)
     68             val mutual_Ts = map substT mutual_Ts0;
       
     69 
       
                    (* Adds U when at least one of its arguments satisfies
                       exists_subtype_in mutual_Ts, then descends into exactly those arguments;
                       every other type leaves the accumulator unchanged. *)
     70             fun add_interesting_subtypes (U as Type (s, Us)) =
       
     71                 (case filter (exists_subtype_in mutual_Ts) Us of [] => I
       
     72                 | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
       
     73               | add_interesting_subtypes _ = I;
       
     74 
       
                    (* Argument types of all constructors of the cluster, instantiated and
                       de-duplicated, and the nested types found inside them. *)
     75             val ctrs = maps #ctrs ctr_sugars;
       
     76             val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
       
     77             val subTs = fold add_interesting_subtypes ctr_Ts [];
       
     78           in
       
                    (* Recurse into the nested types with the whole cluster appended to seen. *)
     79             fold add_nested_types_of subTs (seen @ mutual_Ts)
       
     80           end
       
     81         | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
       
     82             " not associated with new-style datatype (cf. \"datatype_new\")"));
       
     83 
       
            (* All types to process. Since the traversal starts from an empty seen list, the
               cluster of fpT1 comes first and any nested datatypes follow it. *)
     84     val Ts = add_nested_types_of fpT1 [];
       
            (* One binding per type: compatN followed by the type's base name. *)
     85     val bs = map (Binding.name o prefix compatN o base_name_of_typ) Ts;
       
     86     val nn_fp = length fpTs;
       
     87     val nn = length Ts;
       
            (* Constant function that always returns the empty list. *)
     88     val get_indices = K [];
       
            (* With a single type the sugar looked up above is reused; otherwise one sugar is
               looked up per type in Ts. *)
     89     val fp_sugars0 =
       
     90       if nn = 1 then [fp_sugar0] else map (lfp_sugar_of o fst o dest_Type) Ts;
       
     91     val callssss = pad_and_indexify_calls fp_sugars0 nn [];
       
            (* True when the traversal found types beyond the cluster itself. *)
     92     val has_nested = nn > nn_fp;
       
     93 
       
            (* From here on lthy refers to the context returned by mutualize_fp_sugars. *)
     94     val ((_, fp_sugars), lthy) =
       
     95       mutualize_fp_sugars false has_nested Least_FP bs Ts get_indices callssss fp_sugars0 lthy;
       
     96 
       
            (* Only the first resulting sugar is taken apart here; the pattern expects exactly
               one theorem in co_inducts and at least one sugar in the list. *)
     97     val {ctr_sugars, co_inducts = [induct], co_iterss, co_iter_thmsss = iter_thmsss, ...} :: _ =
       
     98       fp_sugars;
       
            (* NOTE(review): conj_dests presumably splits induct into nn conjuncts -- confirm. *)
     99     val inducts = conj_dests nn induct;
       
    100 
       
            (* dtyp translation uses the frozen versions of Ts as the recursive positions. *)
    101     val frozen_Ts = map Type.legacy_freeze_type Ts;
       
    102     val mk_dtyp = dtyp_of_typ frozen_Ts;
       
    103 
       
            (* Constructor name paired with the dtyps of its argument types; only Const terms
               are matched. *)
    104     fun mk_ctr_descr (Const (s, T)) =
       
    105       (s, map mk_dtyp (binder_types (Type.legacy_freeze_type T)));
       
            (* (index, (type name, dtyps of the type arguments, constructor descriptions)). *)
    106     fun mk_typ_descr index (Type (T_name, Ts)) {ctrs, ...} =
       
    107       (index, (T_name, map mk_dtyp Ts, map mk_ctr_descr ctrs));
       
    108 
       
            (* One descriptor entry per type in Ts, numbered from 0. *)
    109     val descr = map3 mk_typ_descr (0 upto nn - 1) frozen_Ts ctr_sugars;
       
            (* Constant names and theorems selected by co_rec_of from the iterator data of the
               first sugar; the theorem lists are flattened into one list. *)
    110     val recs = map (fst o dest_Const o co_rec_of) co_iterss;
       
    111     val rec_thms = flat (map co_rec_of iter_thmsss);
       
    112 
       
            (* Builds the (type name, info record) pair for one sugar, reading the per-type
               theorems from the ctr_sugar at the sugar's index. descr, induct, inducts, recs
               and rec_thms are the same values in every record. *)
    113     fun mk_info {T = Type (T_name0, _), index, ...} =
       
    114       let
       
    115         val {casex, exhaust, nchotomy, injects, distincts, case_thms, case_cong, weak_case_cong,
       
    116           split, split_asm, ...} = nth ctr_sugars index;
       
    117       in
       
    118         (T_name0,
       
    119          {index = index, descr = descr, inject = injects, distinct = distincts, induct = induct,
       
    120          inducts = inducts, exhaust = exhaust, nchotomy = nchotomy, rec_names = recs,
       
    121          rec_rewrites = rec_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
       
    122          case_cong = case_cong, weak_case_cong = weak_case_cong, split = split,
       
    123          split_asm = split_asm})
       
    124       end;
       
    125 
       
            (* Info records are produced for the first nn_fp sugars only, i.e. as many as there
               are types in the cluster. *)
    126     val infos = map mk_info (take nn_fp fp_sugars);
       
    127 
       
            (* Theory transformation: register the records, then call
               Datatype_Data.interpretation_data with the default configuration and the
               registered type names. *)
    128     val register_and_interpret =
       
    129       Datatype_Data.register infos
       
    130       #> Datatype_Data.interpretation_data (Datatype_Aux.default_config, map fst infos)
       
    131   in
       
            (* Applied to the theory underlying lthy through Local_Theory.raw_theory. *)
    132     Local_Theory.raw_theory register_and_interpret lthy
       
    133   end;
       
   134 
       
    (* Registers the local-theory command "datatype_compat". Its argument is a sequence of one
       or more type constants (Scan.repeat1 Parse.type_const), which is handed to
       datatype_compat_cmd. The binding is evaluated for its side effect only. *)
    135 val _ =
       
    136   Outer_Syntax.local_theory @{command_spec "datatype_compat"}
       
    137     "register a new-style datatype as an old-style datatype"
       
    138     (Scan.repeat1 Parse.type_const >> datatype_compat_cmd);
       
   139 
       
   140 end;