src/HOL/Tools/BNF/Tools/bnf_lfp_compat.ML
changeset 55058 4e700eb471d4
parent 54950 f00012c20344
equal deleted inserted replaced
55057:6b0fcbeebaba 55058:4e700eb471d4
       
     1 (*  Title:      HOL/BNF/Tools/bnf_lfp_compat.ML
       
     2     Author:     Jasmin Blanchette, TU Muenchen
       
     3     Copyright   2013
       
     4 
       
     5 Compatibility layer with the old datatype package.
       
     6 *)
       
     7 
       
     signature BNF_LFP_COMPAT =
     sig
       (* Back end of the "datatype_new_compat" command: given the raw names of a complete set of
          mutually recursive new-style datatypes, yields the updated local theory. *)
       val datatype_new_compat_cmd: string list -> local_theory -> local_theory
     end;
       
    12 
       
    13 structure BNF_LFP_Compat : BNF_LFP_COMPAT =
       
    14 struct
       
    15 
       
    16 open Ctr_Sugar
       
    17 open BNF_Util
       
    18 open BNF_FP_Util
       
    19 open BNF_FP_Def_Sugar
       
    20 open BNF_FP_N2M_Sugar
       
    21 
       
    (* Translates a type into the old datatype package's dtyp representation.

       - A free type variable becomes DtTFree.
       - A type constructor application that occurs in recTs becomes DtRec k, where k is the
         position of its first occurrence in recTs.
       - Any other type constructor application becomes DtType, with its arguments translated
         recursively.

       There is no clause for schematic type variables (TVar). *)
    fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a
      | dtyp_of_typ recTs (T as Type (s, Us)) =
        let
          (* find_index yields ~1 when no element of recTs equals T *)
          val k = find_index (fn T' => T = T') recTs;
        in
          if k = ~1 then Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Us)
          else Datatype_Aux.DtRec k
        end;
       
    27 
       
    (* Prefix for the bindings and theorem collection names produced by the command below. *)
    val compatN = "compat_";

    (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)

    (* Implementation of the "datatype_new_compat" command.

       raw_fpT_names: unparsed type constructor names, read in the context of lthy.
       lthy: the local theory to extend.

       Steps:
         1. Read the names and check that they form exactly the set of mutually recursive
            least-fixpoint datatypes that the first name belongs to.
         2. Starting from the first datatype, collect the datatypes reachable through nested
            recursion (add_nested_types_of).
         3. If that collection is larger than the mutual group, call mutualize_fp_sugars;
            otherwise reuse the existing sugar unchanged.
         4. Build old-style descriptors and infos, register them via Datatype_Data, and note the
            fold/induct/rec theorems returned by step 3 (if any) under "compat_"-prefixed names.

       Fails with "error" if a name is not a new-style (least-fixpoint) datatype, if the names are
       not a complete mutual group, or if recursion goes through a type constructor that is not a
       new-style datatype. *)
    fun datatype_new_compat_cmd raw_fpT_names lthy =
      let
        val thy = Proof_Context.theory_of lthy;

        fun not_datatype s = error (quote s ^ " is not a new-style datatype");
        fun not_mutually_recursive ss =
          error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");

        fun read_fpT_name raw_name =
          fst (dest_Type (Proof_Context.read_type_name_proper lthy false raw_name));

        val (fpT_names as fpT_name1 :: _) = map read_fpT_name raw_fpT_names;

        (* Sugar of a least-fixpoint datatype; every other outcome of the lookup is an error. *)
        fun lfp_sugar_of s =
          (case fp_sugar_of lthy s of
            SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
          | _ => not_datatype s);

        (* The result type of the first constructor of each ctr_sugar identifies the members of
           the mutual group of the first datatype. *)
        val {ctr_sugars = group_ctr_sugars, ...} = lfp_sugar_of fpT_name1;
        val fpTs0 as Type (_, var_As) :: _ =
          map (fn ctr_sugar => body_type (fastype_of (hd (#ctrs ctr_sugar)))) group_ctr_sugars;
        val fpT_names' = map (fn T => fst (dest_Type T)) fpTs0;

        val _ =
          if eq_set (op =) (fpT_names, fpT_names') then ()
          else not_mutually_recursive fpT_names;

        (* Fresh type frees that carry the sorts of the original type arguments. *)
        val (unsorted_As, _) = mk_TFrees (length var_As) lthy;
        val As =
          map2 (fn var_A => fn unsorted_A => resort_tfree (Type.sort_of_atyp var_A) unsorted_A)
            var_As unsorted_As;
        val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';

        (* Extends "seen" with T's mutual group and, recursively, with the datatypes occurring in
           constructor argument types that mention a member of that group. *)
        fun add_nested_types_of (T as Type (s, _)) seen =
          if member (op =) seen T then
            seen
          else if s = @{type_name fun} then
            (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
          else
            (case try lfp_sugar_of s of
              NONE =>
                error ("Unsupported recursion via type constructor " ^ quote s ^
                  " not corresponding to new-style datatype (cf. \"datatype_new\")")
            | SOME {T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars = nested_ctr_sugars, ...} =>
                let
                  (* instantiate the stored schematic type T0 to the concrete instance T *)
                  val tyenv = Sign.typ_match thy (T0, T) Vartab.empty;
                  val rho = Vartab.fold (fn (xi, (_, U)) => fn acc => (xi, U) :: acc) tyenv [];
                  val substT = Term.typ_subst_TVars rho;

                  val mutual_Ts = map substT mutual_Ts0;

                  (* keep U if one of its arguments contains a member of mutual_Ts, then descend
                     into exactly those arguments *)
                  fun add_interesting_subtypes (U as Type (_, Us)) acc =
                        let
                          val Us' = filter (exists_subtype_in mutual_Ts) Us;
                        in
                          if null Us' then acc
                          else fold add_interesting_subtypes Us' (insert (op =) U acc)
                        end
                    | add_interesting_subtypes _ acc = acc;

                  val ctrs = maps (fn ctr_sugar => #ctrs ctr_sugar) nested_ctr_sugars;
                  val ctr_Ts =
                    distinct (op =) (maps (fn ctr => binder_types (substT (fastype_of ctr))) ctrs);
                  val subTs = fold add_interesting_subtypes ctr_Ts [];
                in
                  fold add_nested_types_of subTs (seen @ mutual_Ts)
                end);

        val Ts = add_nested_types_of fpT1 [];
        val b_names = map base_name_of_typ Ts;
        val compat_b_names = map (fn b_name => compatN ^ b_name) b_names;
        val compat_bs = map Binding.name compat_b_names;
        val common_name = compatN ^ mk_common_name b_names;
        val nn_fp = length fpTs;
        val nn = length Ts;
        fun get_indices _ = [];
        val fp_sugars0 = map (fn T => lfp_sugar_of (fst (dest_Type T))) Ts;
        val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;

        (* Only when nested datatypes were added to the mutual group is mutualization performed. *)
        val ((fp_sugars, (lfp_sugar_thms, _)), lthy') =
          if nn <= nn_fp then
            ((fp_sugars0, (NONE, NONE)), lthy)
          else
            mutualize_fp_sugars Least_FP compat_bs Ts get_indices callssss fp_sugars0 lthy;

        val {ctr_sugars, co_inducts = [induct], co_iterss, co_iter_thmsss = iter_thmsss, ...} :: _ =
          fp_sugars;
        val inducts = conj_dests nn induct;

        val mk_dtyp = dtyp_of_typ Ts;

        (* (constructor name, dtyps of its arguments) for a constructor instantiated to Us *)
        fun mk_ctr_descr Us ctr =
          let
            val (ctr_name, ctr_T) = dest_Const (mk_ctr Us ctr);
          in
            (ctr_name, map mk_dtyp (binder_types ctr_T))
          end;

        fun mk_typ_descr index (Type (T_name, Us)) ({ctrs, ...} : ctr_sugar) =
          (index, (T_name, map mk_dtyp Us, map (mk_ctr_descr Us) ctrs));

        val descr = map3 mk_typ_descr (0 upto nn - 1) Ts ctr_sugars;
        val recs = map (fn iters => fst (dest_Const (co_rec_of iters))) co_iterss;
        val rec_thms = maps co_rec_of iter_thmsss;

        (* Old-style info record for one datatype, keyed by its type constructor name. *)
        fun mk_info ({T = Type (T_name0, _), index, ...} : fp_sugar) =
          let
            val ctr_sugar = nth ctr_sugars index;
            val info =
              {index = index,
               descr = descr,
               inject = #injects ctr_sugar,
               distinct = #distincts ctr_sugar,
               induct = induct,
               inducts = inducts,
               exhaust = #exhaust ctr_sugar,
               nchotomy = #nchotomy ctr_sugar,
               rec_names = recs,
               rec_rewrites = rec_thms,
               case_name = fst (dest_Const (#casex ctr_sugar)),
               case_rewrites = #case_thms ctr_sugar,
               case_cong = #case_cong ctr_sugar,
               weak_case_cong = #weak_case_cong ctr_sugar,
               split = #split ctr_sugar,
               split_asm = #split_asm ctr_sugar};
          in
            (T_name0, info)
          end;

        val infos = map mk_info (take nn_fp fp_sugars);

        (* One entry for Local_Theory.notes: theorems thms stored as "<qualifier>.<thmN>". *)
        fun mk_note qualifier thmN attrs thms =
          ((Binding.qualify true qualifier (Binding.name thmN), attrs), [(thms, [])]);

        val all_notes =
          (case lfp_sugar_thms of
            NONE => []
          | SOME ((induct_thms, induct_thm, induct_attrs), (fold_thmss, rec_thmss, _)) =>
              let
                (* the joint induction rule is only noted for more than one type *)
                val common_notes =
                  if nn > 1 then [mk_note common_name inductN induct_attrs [induct_thm]] else [];

                fun notes_for (thmN, thmss, attrs) =
                  if forall null thmss then
                    []
                  else
                    map2 (fn b_name => mk_note b_name thmN attrs) compat_b_names thmss;

                val notes =
                  [(foldN, fold_thmss, []),
                   (inductN, map single induct_thms, induct_attrs),
                   (recN, rec_thmss, [])]
                  |> filter_out (fn (_, thmss, _) => null thmss)
                  |> maps notes_for;
              in
                common_notes @ notes
              end);

        val register_interpret =
          Datatype_Data.register infos
          #> Datatype_Data.interpretation_data (Datatype_Aux.default_config, map fst infos);
      in
        snd (Local_Theory.notes all_notes (Local_Theory.raw_theory register_interpret lthy'))
      end;
       
   166 
       
   (* Outer syntax setup: the "datatype_new_compat" command parses one or more type constructor
      names and hands them to datatype_new_compat_cmd. *)
   val _ =
     Outer_Syntax.local_theory
       @{command_spec "datatype_new_compat"}
       "register new-style datatypes as old-style datatypes"
       (Scan.repeat1 Parse.type_const >> (fn raw_fpT_names => datatype_new_compat_cmd raw_fpT_names));
       
   171 
       
   172 end;