allow arguments to 'datatype_compat' in disorder
authorblanchet
Tue, 08 Apr 2014 18:16:47 +0200
changeset 56455 1ff66e72628b
parent 56454 e9e82384e5a1
child 56456 39281b3e4fac
allow arguments to 'datatype_compat' in disorder
src/HOL/BNF_Examples/Compat.thy
src/HOL/Tools/BNF/bnf_lfp_compat.ML
--- a/src/HOL/BNF_Examples/Compat.thy	Tue Apr 08 18:06:21 2014 +0200
+++ b/src/HOL/BNF_Examples/Compat.thy	Tue Apr 08 18:16:47 2014 +0200
@@ -52,14 +52,12 @@
 | "f_mylist (MyCons _ xs) = Suc (f_mylist xs)"
 
 datatype_new foo' = FooNil | FooCons bar' foo' and bar' = Bar
-(* FIXME
 datatype_compat bar' foo'
 
 fun f_foo and f_bar where
   "f_foo FooNil = 0"
 | "f_foo (FooCons bar foo) = Suc (f_foo foo) + f_bar bar"
 | "f_bar Bar = Suc 0"
-*)
 
 locale opt begin
 
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Tue Apr 08 18:06:21 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Tue Apr 08 18:16:47 2014 +0200
@@ -40,7 +40,7 @@
   end
 
 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
-fun datatype_compat_cmd raw_fpT_names lthy =
+fun datatype_compat_cmd raw_fpT_names0 lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -48,24 +48,23 @@
     fun not_mutually_recursive ss =
       error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");
 
-    val fpT_names =
+    val fpT_names0 =
       map (fst o dest_Type o Proof_Context.read_type_name {proper = true, strict = false} lthy)
-        raw_fpT_names;
+        raw_fpT_names0;
 
     fun lfp_sugar_of s =
       (case fp_sugar_of lthy s of
         SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
       | _ => not_datatype s);
 
-    val fp_ctr_sugars = map (#ctr_sugar o lfp_sugar_of) fpT_names;
-    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) fp_ctr_sugars;
-    val fpT_names' = map (fst o dest_Type) fpTs0;
+    val fpTs0 as Type (_, var_As) :: _ = #Ts (#fp_res (lfp_sugar_of (hd fpT_names0)));
+    val fpT_names = map (fst o dest_Type) fpTs0;
 
-    val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;
+    val _ = eq_set (op =) (fpT_names0, fpT_names) orelse not_mutually_recursive fpT_names0;
 
     val (As_names, _) = lthy |> Variable.variant_fixes (map (fn TVar ((s, _), _) => s) var_As);
     val As = map2 (fn s => fn TVar (_, S) => TFree (s, S)) As_names var_As;
-    val fpTs = map (fn s => Type (s, As)) fpT_names';
+    val fpTs = map (fn s => Type (s, As)) fpT_names;
 
     val nn_fp = length fpTs;
 
@@ -75,6 +74,7 @@
     fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
       (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs));
 
+    val fp_ctr_sugars = map (#ctr_sugar o lfp_sugar_of) fpT_names;
     val orig_descr = map3 mk_typ_descr (0 upto nn_fp - 1) fpTs fp_ctr_sugars;
     val all_infos = Datatype_Data.get_all thy;
     val (orig_descr' :: nested_descrs, _) =