src/HOL/Tools/BNF/Tools/bnf_lfp_compat.ML
author blanchet
Mon Jan 20 18:24:56 2014 +0100 (2014-01-20)
changeset 55058 4e700eb471d4
parent 54950 src/HOL/BNF/Tools/bnf_lfp_compat.ML@f00012c20344
permissions -rw-r--r--
moved BNF files to 'HOL'
blanchet@53303
     1
(*  Title:      HOL/Tools/BNF/Tools/bnf_lfp_compat.ML
blanchet@53303
     2
    Author:     Jasmin Blanchette, TU Muenchen
blanchet@53303
     3
    Copyright   2013
blanchet@53303
     4
blanchet@53303
     5
Compatibility layer with the old datatype package.
blanchet@53303
     6
*)
blanchet@53303
     7
blanchet@53303
     8
signature BNF_LFP_COMPAT =
sig
  (* Implementation of the datatype_new_compat command: takes the (unparsed) names of
     new-style datatypes and registers them with the old datatype package. *)
  val datatype_new_compat_cmd : string list -> local_theory -> local_theory
end;
blanchet@53303
    12
blanchet@53303
    13
structure BNF_LFP_Compat : BNF_LFP_COMPAT =
struct

open Ctr_Sugar
open BNF_Util
open BNF_FP_Util
open BNF_FP_Def_Sugar
open BNF_FP_N2M_Sugar

(* Translates a HOL type into the old datatype package's dtyp representation. A type that occurs
   in recTs becomes DtRec i, where i is its position in recTs; other type constructors are
   translated structurally. Schematic type variables are not expected here and are rejected
   with an explicit TYPE exception rather than a bare Match. *)
fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a
  | dtyp_of_typ recTs (T as Type (s, Ts)) =
    (case find_index (curry (op =) T) recTs of
      ~1 => Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Ts)
    | kk => Datatype_Aux.DtRec kk)
  | dtyp_of_typ _ T = raise TYPE ("dtyp_of_typ", [T], []);

(* Prefix of the bindings introduced by the command: the auxiliary (mutualized) type names and
   the qualifiers of the theorems noted for them. *)
val compatN = "compat_";

(* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
(* Registers the named new-style datatypes with the old datatype package.

   The names must denote a complete set of mutually recursive least-fixpoint datatypes
   (in any order); otherwise an error is raised. If the datatypes recurse through other
   new-style datatypes (nested recursion), their sugar is first mutualized via
   mutualize_fp_sugars under compat_-prefixed bindings, and the resulting fold/induct/rec
   theorems are noted under those names. *)
fun datatype_new_compat_cmd raw_fpT_names lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    fun not_datatype s = error (quote s ^ " is not a new-style datatype");
    fun not_mutually_recursive ss =
      error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");

    (* Non-empty: the command parser uses Scan.repeat1. *)
    val (fpT_names as fpT_name1 :: _) =
      map (fst o dest_Type o Proof_Context.read_type_name_proper lthy false) raw_fpT_names;

    (* Only least-fixpoint (inductive) datatypes are accepted. *)
    fun lfp_sugar_of s =
      (case fp_sugar_of lthy s of
        SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
      | _ => not_datatype s);

    (* The first named type determines its whole mutual group; the user must have named exactly
       that group. Named ctr_sugars1 so it is not confused with the ctr_sugars of the
       (possibly mutualized) sugar bound further below. *)
    val {ctr_sugars = ctr_sugars1, ...} = lfp_sugar_of fpT_name1;
    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) ctr_sugars1;
    val fpT_names' = map (fst o dest_Type) fpTs0;

    val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;

    (* Fresh type variables carrying the same sorts as the datatypes' own type arguments. *)
    val (unsorted_As, _) = lthy |> mk_TFrees (length var_As);
    val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
    val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';

    (* Collects, in discovery order, the datatype instances reachable from T through nested
       recursion: each newly met datatype contributes its whole (instantiated) mutual group,
       appended after seen. Function types are not followed (only a warning is emitted).
       A non-Type argument contributes nothing. *)
    fun add_nested_types_of (T as Type (s, _)) seen =
        if member (op =) seen T then
          seen
        else if s = @{type_name fun} then
          (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
        else
          (case try lfp_sugar_of s of
            SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
            let
              (* Instantiate the generic datatype T0 to the concrete instance T. *)
              val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
              val substT = Term.typ_subst_TVars rho;

              val mutual_Ts = map substT mutual_Ts0;

              (* Subtypes of a constructor argument type that mention one of the mutual types,
                 i.e. the places where the recursion passes through another type constructor. *)
              fun add_interesting_subtypes (U as Type (_, Us)) =
                  (case filter (exists_subtype_in mutual_Ts) Us of
                    [] => I
                  | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
                | add_interesting_subtypes _ = I;

              val ctrs = maps #ctrs ctr_sugars;
              val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
              val subTs = fold add_interesting_subtypes ctr_Ts [];
            in
              fold add_nested_types_of subTs (seen @ mutual_Ts)
            end
          | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
              " not corresponding to new-style datatype (cf. \"datatype_new\")"))
      | add_nested_types_of _ seen = seen;

    (* Ts starts with the user's mutual group (fpT1's group), followed by nested types. *)
    val Ts = add_nested_types_of fpT1 [];
    val b_names = map base_name_of_typ Ts;
    val compat_b_names = map (prefix compatN) b_names;
    val compat_bs = map Binding.name compat_b_names;
    val common_name = compatN ^ mk_common_name b_names;
    val nn_fp = length fpTs;
    val nn = length Ts;
    val get_indices = K [];
    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
    val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;

    (* More types than the user's group means nested recursion: mutualize. Otherwise the
       existing sugar is reused as is and no extra theorems are produced. *)
    val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
      if nn > nn_fp then
        mutualize_fp_sugars Least_FP compat_bs Ts get_indices callssss fp_sugars0 lthy
      else
        ((fp_sugars0, (NONE, NONE)), lthy);

    val {ctr_sugars, co_inducts = [induct], co_iterss, co_iter_thmsss = iter_thmsss, ...} :: _ =
      fp_sugars;
    val inducts = conj_dests nn induct;

    val mk_dtyp = dtyp_of_typ Ts;

    (* Old-package descriptor: one entry per type in Ts, numbered 0 .. nn - 1, giving its name,
       its type arguments and its constructors (name and argument dtyps). *)
    fun mk_ctr_descr Ts = mk_ctr Ts #> dest_Const ##> (binder_types #> map mk_dtyp);
    fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
      (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs));

    val descr = map3 mk_typ_descr (0 upto nn - 1) Ts ctr_sugars;
    val recs = map (fst o dest_Const o co_rec_of) co_iterss;
    val rec_thms = maps co_rec_of iter_thmsss;

    (* Builds the old-package info record for one datatype from its new-style sugar. *)
    fun mk_info ({T = Type (T_name0, _), index, ...} : fp_sugar) =
      let
        val {casex, exhaust, nchotomy, injects, distincts, case_thms, case_cong, weak_case_cong,
          split, split_asm, ...} = nth ctr_sugars index;
      in
        (T_name0,
         {index = index, descr = descr, inject = injects, distinct = distincts, induct = induct,
         inducts = inducts, exhaust = exhaust, nchotomy = nchotomy, rec_names = recs,
         rec_rewrites = rec_thms, case_name = fst (dest_Const casex), case_rewrites = case_thms,
         case_cong = case_cong, weak_case_cong = weak_case_cong, split = split,
         split_asm = split_asm})
      end;

    (* Only the first nn_fp sugars -- the user's own group -- are registered. *)
    val infos = map mk_info (take nn_fp fp_sugars);

    (* Theorems of the mutualized types, noted under compat_ names (only after mutualization). *)
    val all_notes =
      (case lfp_sugar_thms of
        NONE => []
      | SOME ((induct_thms, induct_thm, induct_attrs), (fold_thmss, rec_thmss, _)) =>
        let
          val common_notes =
            (if nn > 1 then [(inductN, [induct_thm], induct_attrs)] else [])
            |> filter_out (null o #2)
            |> map (fn (thmN, thms, attrs) =>
              ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));

          val notes =
            [(foldN, fold_thmss, []),
             (inductN, map single induct_thms, induct_attrs),
             (recN, rec_thmss, [])]
            |> filter_out (null o #2)
            |> maps (fn (thmN, thmss, attrs) =>
              if forall null thmss then
                []
              else
                map2 (fn b_name => fn thms =>
                    ((Binding.qualify true b_name (Binding.name thmN), attrs), [(thms, [])]))
                  compat_b_names thmss);
        in
          common_notes @ notes
        end);

    val register_interpret =
      Datatype_Data.register infos
      #> Datatype_Data.interpretation_data (Datatype_Aux.default_config, map fst infos)
  in
    lthy
    |> Local_Theory.raw_theory register_interpret
    |> Local_Theory.notes all_notes |> snd
  end;

(* Outer syntax: datatype_new_compat followed by one or more type constructor names. *)
val _ =
  Outer_Syntax.local_theory @{command_spec "datatype_new_compat"}
    "register new-style datatypes as old-style datatypes"
    (Scan.repeat1 Parse.type_const >> datatype_new_compat_cmd);

end;