src/HOL/Tools/BNF/bnf_lfp_compat.ML
changeset 58125 a2ba381607fb
parent 58123 62765d39539f
child 58126 3831312eb476
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 01 16:17:47 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Mon Sep 01 16:17:47 2014 +0200
@@ -10,20 +10,24 @@
 
 signature BNF_LFP_COMPAT =
 sig
-  datatype nesting_mode = Keep_Nesting | Unfold_Nesting_if_Possible | Always_Unfold_Nesting
+  datatype nesting_preference = Keep_Nesting | Unfold_Nesting
 
-  val get_all: theory -> nesting_mode -> Old_Datatype_Aux.info Symtab.table
-  val get_info: theory -> nesting_mode -> string -> Old_Datatype_Aux.info option
-  val the_info: theory -> nesting_mode -> string -> Old_Datatype_Aux.info
-  val the_spec: theory -> nesting_mode -> string -> (string * sort) list * (string * typ list) list
-  val the_descr: theory -> nesting_mode -> string list ->
+  val get_all: theory -> nesting_preference -> Old_Datatype_Aux.info Symtab.table
+  val get_info: theory -> nesting_preference -> string -> Old_Datatype_Aux.info option
+  val the_info: theory -> nesting_preference -> string -> Old_Datatype_Aux.info
+  val the_spec: theory -> nesting_preference -> string ->
+    (string * sort) list * (string * typ list) list
+  val the_descr: theory -> nesting_preference -> string list ->
     Old_Datatype_Aux.descr * (string * sort) list * string list * string
     * (string list * string list) * (typ list * typ list)
-  val get_constrs: theory -> nesting_mode -> string -> (string * typ) list option
-  val interpretation: nesting_mode ->
+  val get_constrs: theory -> nesting_preference -> string -> (string * typ) list option
+  val interpretation: nesting_preference ->
     (Old_Datatype_Aux.config -> string list -> theory -> theory) -> theory -> theory
+  val datatype_compat: string list -> local_theory -> local_theory
+  val datatype_compat_global: string list -> theory -> theory
   val datatype_compat_cmd: string list -> local_theory -> local_theory
-  val add_datatype: nesting_mode -> Old_Datatype_Aux.spec list -> theory -> string list * theory
+  val add_datatype: nesting_preference -> Old_Datatype_Aux.spec list -> theory ->
+    string list * theory
 end;
 
 structure BNF_LFP_Compat : BNF_LFP_COMPAT =
@@ -38,7 +42,7 @@
 
 val compatN = "compat_";
 
-datatype nesting_mode = Keep_Nesting | Unfold_Nesting_if_Possible | Always_Unfold_Nesting;
+datatype nesting_preference = Keep_Nesting | Unfold_Nesting;
 
 fun reindex_desc desc =
   let
@@ -57,8 +61,8 @@
       map (fn (_, (s, Ds, sDss)) => (s, map perm_dtyp Ds, map (apsnd (map perm_dtyp)) sDss)) desc
   end;
 
-fun mk_infos_of_mutually_recursive_new_datatypes nesting_mode need_co_inducts_recs check_names
-    raw_fpT_names0 lthy =
+fun mk_infos_of_mutually_recursive_new_datatypes nesting_pref need_co_inducts_recs check_names
+    fpT_names0 lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -66,10 +70,6 @@
     fun not_mutually_recursive ss =
       error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");
 
-    val fpT_names0 =
-      map (fst o dest_Type o Proof_Context.read_type_name {proper = true, strict = false} lthy)
-        raw_fpT_names0;
-
     fun lfp_sugar_of s =
       (case fp_sugar_of lthy s of
         SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
@@ -96,10 +96,8 @@
     val orig_descr = map3 mk_typ_descr (0 upto nn_fp - 1) fpTs fp_ctr_sugars;
     val all_infos = Old_Datatype_Data.get_all thy;
     val (orig_descr' :: nested_descrs) =
-      if nesting_mode = Keep_Nesting then
-        [orig_descr]
-      else
-        fst (Old_Datatype_Aux.unfold_datatypes lthy orig_descr all_infos orig_descr nn_fp);
+      if nesting_pref = Keep_Nesting then [orig_descr]
+      else fst (Old_Datatype_Aux.unfold_datatypes lthy orig_descr all_infos orig_descr nn_fp);
 
     fun cliquify_descr [] = []
       | cliquify_descr [entry] = [[entry]]
@@ -172,56 +170,61 @@
     (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy')
   end;
 
-fun infos_of_new_datatype_mutual_cluster lthy nesting_mode raw_fpt_names01 =
-  mk_infos_of_mutually_recursive_new_datatypes nesting_mode false subset [raw_fpt_names01] lthy
-  |> #5;
+fun infos_of_new_datatype_mutual_cluster lthy nesting_pref fpT_name =
+  let
+    fun infos_of nesting_pref =
+      #5 (mk_infos_of_mutually_recursive_new_datatypes nesting_pref false subset [fpT_name] lthy);
+  in
+    infos_of nesting_pref
+    handle ERROR _ => if nesting_pref = Unfold_Nesting then infos_of Keep_Nesting else []
+  end;
 
-fun get_all thy nesting_mode =
+fun get_all thy nesting_pref =
   let
     val lthy = Named_Target.theory_init thy;
     val old_info_tab = Old_Datatype_Data.get_all thy;
     val new_T_names = BNF_FP_Def_Sugar.fp_sugars_of_global thy
       |> map_filter (try (fn {T = Type (s, _), fp_res_index = 0, ...} => s));
-    val new_infos = maps (infos_of_new_datatype_mutual_cluster lthy nesting_mode) new_T_names;
+    val new_infos = maps (infos_of_new_datatype_mutual_cluster lthy nesting_pref) new_T_names;
   in
-    fold (if nesting_mode = Keep_Nesting then Symtab.update else Symtab.default) new_infos
+    fold (if nesting_pref = Keep_Nesting then Symtab.update else Symtab.default) new_infos
       old_info_tab
   end;
 
-fun get_one get_old get_new thy nesting_mode x =
+fun get_one get_old get_new thy nesting_pref x =
   let
     val (get_fst, get_snd) =
-      (get_old thy, get_new thy nesting_mode) |> nesting_mode = Keep_Nesting ? swap
+      (get_old thy, get_new thy nesting_pref) |> nesting_pref = Keep_Nesting ? swap
   in
     (case get_fst x of NONE => get_snd x | res => res)
   end;
 
-fun get_info_of_new_datatype thy nesting_mode T_name =
+fun get_info_of_new_datatype thy nesting_pref T_name =
   let val lthy = Named_Target.theory_init thy in
-    AList.lookup (op =) (infos_of_new_datatype_mutual_cluster lthy nesting_mode T_name) T_name
+    AList.lookup (op =) (infos_of_new_datatype_mutual_cluster lthy nesting_pref T_name) T_name
   end;
 
 val get_info = get_one Old_Datatype_Data.get_info get_info_of_new_datatype;
 
-fun the_info thy nesting_mode T_name =
-  (case get_info thy nesting_mode T_name of
+fun the_info thy nesting_pref T_name =
+  (case get_info thy nesting_pref T_name of
     SOME info => info
   | NONE => error ("Unknown datatype " ^ quote T_name));
 
-fun the_spec thy nesting_mode T_name =
+fun the_spec thy nesting_pref T_name =
   let
-    val {descr, index, ...} = the_info thy nesting_mode T_name;
+    val {descr, index, ...} = the_info thy nesting_pref T_name;
     val (_, Ds, ctrs0) = the (AList.lookup (op =) descr index);
     val Ts = map Old_Datatype_Aux.dest_DtTFree Ds;
     val ctrs = map (apsnd (map (Old_Datatype_Aux.typ_of_dtyp descr))) ctrs0;
   in (Ts, ctrs) end;
 
-fun the_descr thy nesting_mode (T_names0 as T_name01 :: _) =
+fun the_descr thy nesting_pref (T_names0 as T_name01 :: _) =
   let
     fun not_mutually_recursive ss =
       error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive datatypes");
 
-    val info = the_info thy nesting_mode T_name01;
+    val info = the_info thy nesting_pref T_name01;
     val descr = #descr info;
 
     val (_, Ds, _) = the (AList.lookup (op =) descr (#index info));
@@ -248,8 +251,8 @@
     (descr, vs, T_names, prefix, (names, auxnames), (Ts, Us))
   end;
 
-fun get_constrs thy nesting_mode T_name =
-  try (the_spec thy nesting_mode) T_name
+fun get_constrs thy nesting_pref T_name =
+  try (the_spec thy nesting_pref) T_name
   |> Option.map (fn (tfrees, ctrs) =>
     let
       fun varify_tfree (s, S) = TVar ((s, 0), S);
@@ -263,32 +266,31 @@
       map (apsnd mk_ctr_typ) ctrs
     end);
 
-fun old_interpretation_of nesting_mode f config T_names thy =
-  if nesting_mode <> Keep_Nesting orelse exists (is_none o fp_sugar_of_global thy) T_names then
+fun old_interpretation_of nesting_pref f config T_names thy =
+  if nesting_pref = Unfold_Nesting orelse exists (is_none o fp_sugar_of_global thy) T_names then
     f config T_names thy
   else
     thy;
 
-fun new_interpretation_of nesting_mode f fp_sugars thy =
+fun new_interpretation_of nesting_pref f fp_sugars thy =
   let val T_names = map (fst o dest_Type o #T) fp_sugars in
-    if nesting_mode = Keep_Nesting orelse
+    if nesting_pref = Keep_Nesting orelse
         exists (is_none o Old_Datatype_Data.get_info thy) T_names then
       f Old_Datatype_Aux.default_config T_names thy
     else
       thy
   end;
 
-fun interpretation nesting_mode f =
-  Old_Datatype_Data.interpretation (old_interpretation_of nesting_mode f)
-  #> fp_sugar_interpretation (new_interpretation_of nesting_mode f);
+fun interpretation nesting_pref f =
+  Old_Datatype_Data.interpretation (old_interpretation_of nesting_pref f)
+  #> fp_sugar_interpretation (new_interpretation_of nesting_pref f);
 
 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
 
-fun datatype_compat_cmd fpT_names lthy =
+fun datatype_compat fpT_names lthy =
   let
     val (nn, b_names, compat_b_names, lfp_sugar_thms, infos, lthy') =
-      mk_infos_of_mutually_recursive_new_datatypes Unfold_Nesting_if_Possible true eq_set fpT_names
-        lthy;
+      mk_infos_of_mutually_recursive_new_datatypes Unfold_Nesting true eq_set fpT_names lthy;
 
     val all_notes =
       (case lfp_sugar_thms of
@@ -328,7 +330,21 @@
     |> snd
   end;
 
-fun add_datatype nesting_mode old_specs thy =
+fun datatype_compat_global fpT_names =
+  Named_Target.theory_init
+  #> datatype_compat fpT_names
+  #> Named_Target.exit;
+
+fun datatype_compat_cmd raw_fpT_names lthy =
+  let
+    val fpT_names =
+      map (fst o dest_Type o Proof_Context.read_type_name {proper = true, strict = false} lthy)
+        raw_fpT_names;
+  in
+    datatype_compat fpT_names lthy
+  end;
+
+fun add_datatype nesting_pref old_specs thy =
   let
     val fpT_names = map (Sign.full_name thy o #1 o fst) old_specs;
 
@@ -345,8 +361,8 @@
      thy
      |> Named_Target.theory_init
      |> co_datatypes Least_FP construct_lfp ((false, false), new_specs)
-     |> nesting_mode <> Keep_Nesting ? datatype_compat_cmd fpT_names
-     |> Named_Target.exit)
+     |> Named_Target.exit
+     |> nesting_pref = Unfold_Nesting ? perhaps (try (datatype_compat_global fpT_names)))
   end;
 
 val _ =