src/HOL/Tools/BNF/Tools/bnf_lfp_compat.ML
author blanchet
Mon Jan 20 18:24:56 2014 +0100 (2014-01-20)
changeset 55058 4e700eb471d4
parent 54950 src/HOL/BNF/Tools/bnf_lfp_compat.ML@f00012c20344
permissions -rw-r--r--
moved BNF files to 'HOL'
     1 (*  Title:      HOL/BNF/Tools/bnf_lfp_compat.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2013
     4 
     5 Compatibility layer with the old datatype package.
     6 *)
     7 
signature BNF_LFP_COMPAT =
sig
  (* Implementation of the "datatype_new_compat" command: takes the names of new-style
     datatypes and registers them as old-style datatypes, transforming the local theory. *)
  val datatype_new_compat_cmd : string list -> local_theory -> local_theory
end;
    12 
    13 structure BNF_LFP_Compat : BNF_LFP_COMPAT =
    14 struct
    15 
    16 open Ctr_Sugar
    17 open BNF_Util
    18 open BNF_FP_Util
    19 open BNF_FP_Def_Sugar
    20 open BNF_FP_N2M_Sugar
    21 
    22 fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a
    23   | dtyp_of_typ recTs (T as Type (s, Ts)) =
    24     (case find_index (curry (op =) T) recTs of
    25       ~1 => Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Ts)
    26     | kk => Datatype_Aux.DtRec kk);
    27 
(* Prefix for the names generated by datatype_new_compat_cmd: the bindings passed to the
   mutualization and the qualifiers under which compatibility theorems are noted. *)
val compatN = "compat_";
    29 
(* Implementation of the "datatype_new_compat" command.

   raw_fpT_names: the type names given by the user, as parsed (at least one).
   lthy: the local theory to extend.

   The command works in five steps:
   1. Resolve the names and check that they form exactly one complete set of mutually
      recursive new-style (least fixpoint) datatypes.
   2. Collect the new-style datatypes nested in their constructor argument types.
   3. When nesting is present, mutualize all of these types.
   4. Register the user-given types with the old datatype package (Datatype_Data).
   5. Note the generated fold/induct/rec theorems under names prefixed with compatN.

   It fails with an error in three cases:
   - a name is not a new-style datatype;
   - the names are not a complete mutually recursive set;
   - recursion goes through a type constructor that is not a new-style datatype.
   Recursion through function types only produces a warning.

   TODO: graceful failure for local datatypes -- perhaps by making the command global *)
fun datatype_new_compat_cmd raw_fpT_names lthy =
  let
    (* Background theory; needed for Sign.typ_match below. *)
    val thy = Proof_Context.theory_of lthy;

    fun not_datatype s = error (quote s ^ " is not a new-style datatype");
    fun not_mutually_recursive ss =
      error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");

    (* Fully qualified type names. The command parser uses Scan.repeat1, so the list
       is non-empty and the pattern below matches. *)
    val (fpT_names as fpT_name1 :: _) =
      map (fst o dest_Type o Proof_Context.read_type_name_proper lthy false) raw_fpT_names;

    (* The fp_sugar of s, provided it is a least-fixpoint (new-style) datatype;
       otherwise an error. *)
    fun lfp_sugar_of s =
      (case fp_sugar_of lthy s of
        SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
      | _ => not_datatype s);

    (* Recover the mutual group of the first type from the result types of its
       constructors; var_As are the type arguments of the first group member. *)
    val {ctr_sugars, ...} = lfp_sugar_of fpT_name1;
    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) ctr_sugars;
    val fpT_names' = map (fst o dest_Type) fpTs0;

    (* The user must name exactly the types of that group (compared as sets). *)
    val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;

    (* Instantiate the group's type arguments with type variables obtained from
       mk_TFrees, re-sorted to the sorts of var_As. *)
    val (unsorted_As, _) = lthy |> mk_TFrees (length var_As);
    val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
    val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';

    (* Extend seen with the types reachable from T, unless T is already in seen.
       - For a new-style datatype T, add its whole mutual group instantiated to T's
         arguments, then recurse into the constructor argument types that mention
         the group.
       - Function types are not explored; only a warning is issued.
       - Any other type constructor is an error.
       Only Type arguments are handled. All calls here pass Type values: fpT1, and the
       U's collected by add_interesting_subtypes. *)
    fun add_nested_types_of (T as Type (s, _)) seen =
      if member (op =) seen T then
        seen
      else if s = @{type_name fun} then
        (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
      else
        (case try lfp_sugar_of s of
          SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
          let
            (* Substitution, as (indexname, typ) pairs, that maps T0 (the type stored in
               the fp_sugar) to T. *)
            val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
            val substT = Term.typ_subst_TVars rho;

            val mutual_Ts = map substT mutual_Ts0;

            (* Collect every compound type U that has an argument containing a member
               of mutual_Ts, and descend into those arguments. *)
            fun add_interesting_subtypes (U as Type (_, Us)) =
                (case filter (exists_subtype_in mutual_Ts) Us of [] => I
                | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
              | add_interesting_subtypes _ = I;

            val ctrs = maps #ctrs ctr_sugars;
            val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
            val subTs = fold add_interesting_subtypes ctr_Ts [];
          in
            fold add_nested_types_of subTs (seen @ mutual_Ts)
          end
        | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
            " not corresponding to new-style datatype (cf. \"datatype_new\")"));

    (* All types involved. The mutual group of fpT1 comes first (seen starts empty),
       followed by any nested datatypes. *)
    val Ts = add_nested_types_of fpT1 [];
    val b_names = map base_name_of_typ Ts;
    val compat_b_names = map (prefix compatN) b_names;
    val compat_bs = map Binding.name compat_b_names;
    val common_name = compatN ^ mk_common_name b_names;
    val nn_fp = length fpTs;
    val nn = length Ts;
    val get_indices = K [];
    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
    val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;

    (* nn > nn_fp means nested datatypes were found, so all of Ts are mutualized under
       the compat_ bindings; this may also yield new theorems (lfp_sugar_thms).
       Otherwise the existing sugars are reused and no new theorems are produced. *)
    val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
      if nn > nn_fp then
        mutualize_fp_sugars Least_FP compat_bs Ts get_indices callssss fp_sugars0 lthy
      else
        ((fp_sugars0, (NONE, NONE)), lthy);

    (* Constructors, induction rule, iterators and iterator theorems, taken from the
       first sugar. This ctr_sugars shadows the earlier one and must have one entry per
       type in Ts (map3 below requires equal lengths). The pattern expects exactly one
       induction rule, which conj_dests splits into nn rules. *)
    val {ctr_sugars, co_inducts = [induct], co_iterss, co_iter_thmsss = iter_thmsss, ...} :: _ =
      fp_sugars;
    val inducts = conj_dests nn induct;

    (* Members of Ts become DtRec references (see dtyp_of_typ). *)
    val mk_dtyp = dtyp_of_typ Ts;

    (* Old-style descriptor entries. A constructor is described by its name and its
       argument dtyps; a type by (index, (type name, argument dtyps, constructors)). *)
    fun mk_ctr_descr Ts = mk_ctr Ts #> dest_Const ##> (binder_types #> map mk_dtyp);
    fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
      (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs));

    val descr = map3 mk_typ_descr (0 upto nn - 1) Ts ctr_sugars;
    (* Names of the recursor constants and the flattened list of recursor equations. *)
    val recs = map (fst o dest_Const o co_rec_of) co_iterss;
    val rec_thms = flat (map co_rec_of iter_thmsss);

    (* Old-style info record for one type, keyed by its type name. index selects the
       constructor sugar within ctr_sugars. *)
    fun mk_info ({T = Type (T_name0, _), index, ...} : fp_sugar) =
      let
        val {casex, exhaust, nchotomy, injects, distincts, case_thms, case_cong, weak_case_cong,
          split, split_asm, ...} = nth ctr_sugars index;
      in
        (T_name0,
         {index = index, descr = descr, inject = injects, distinct = distincts, induct = induct,
         inducts = inducts, exhaust = exhaust, nchotomy = nchotomy, rec_names = recs,
         rec_rewrites = rec_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
         case_cong = case_cong, weak_case_cong = weak_case_cong, split = split,
         split_asm = split_asm})
      end;

    (* Only the first nn_fp sugars are registered. These are presumably the user-given
       group, which comes first in Ts; the nested auxiliary types are not registered. *)
    val infos = map mk_info (take nn_fp fp_sugars);

    (* Theorems to note; non-empty only when mutualization produced new theorems.
       - A combined induction rule is noted under common_name when nn > 1.
       - fold, induct and rec theorems are noted per type, qualified by the compat_
         names. A kind whose theorem lists are all empty is skipped. *)
    val all_notes =
      (case lfp_sugar_thms of
        NONE => []
      | SOME ((induct_thms, induct_thm, induct_attrs), (fold_thmss, rec_thmss, _)) =>
        let
          val common_notes =
            (if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])
            |> filter_out (null o #2)
            |> map (fn (thmN, thms, attrs) =>
              ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));

          val notes =
            [(foldN, fold_thmss, []),
             (inductN, map single induct_thms, induct_attrs),
             (recN, rec_thmss, [])]
            |> filter_out (null o #2)
            |> maps (fn (thmN, thmss, attrs) =>
              if forall null thmss then
                []
              else
                map2 (fn b_name => fn thms =>
                    ((Binding.qualify true b_name (Binding.name thmN), attrs), [(thms, [])]))
                  compat_b_names thmss);
        in
          common_notes @ notes
        end);

    (* Register the infos with the old datatype package, then apply
       Datatype_Data.interpretation_data to their type names (default configuration). *)
    val register_interpret =
      Datatype_Data.register infos
      #> Datatype_Data.interpretation_data (Datatype_Aux.default_config, map fst infos)
  in
    (* Registration is applied to the underlying theory via Local_Theory.raw_theory,
       which is what the TODO above about local datatypes refers to. The theorems are
       then noted in the local theory. *)
    lthy
    |> Local_Theory.raw_theory register_interpret
    |> Local_Theory.notes all_notes |> snd
  end;
   166 
   167 val _ =
   168   Outer_Syntax.local_theory @{command_spec "datatype_new_compat"}
   169     "register new-style datatypes as old-style datatypes"
   170     (Scan.repeat1 Parse.type_const >> datatype_new_compat_cmd);
   171 
   172 end;