src/HOL/Tools/BNF/bnf_lfp_compat.ML
changeset 56455 1ff66e72628b
parent 56453 00548d372f02
child 56484 c451cf8b29c8
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Tue Apr 08 18:06:21 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Tue Apr 08 18:16:47 2014 +0200
@@ -40,7 +40,7 @@
   end
 
 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
-fun datatype_compat_cmd raw_fpT_names lthy =
+fun datatype_compat_cmd raw_fpT_names0 lthy =
   let
     val thy = Proof_Context.theory_of lthy;
 
@@ -48,24 +48,23 @@
     fun not_mutually_recursive ss =
       error ("{" ^ commas ss ^ "} is not a complete set of mutually recursive new-style datatypes");
 
-    val fpT_names =
+    val fpT_names0 =
       map (fst o dest_Type o Proof_Context.read_type_name {proper = true, strict = false} lthy)
-        raw_fpT_names;
+        raw_fpT_names0;
 
     fun lfp_sugar_of s =
       (case fp_sugar_of lthy s of
         SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
       | _ => not_datatype s);
 
-    val fp_ctr_sugars = map (#ctr_sugar o lfp_sugar_of) fpT_names;
-    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) fp_ctr_sugars;
-    val fpT_names' = map (fst o dest_Type) fpTs0;
+    val fpTs0 as Type (_, var_As) :: _ = #Ts (#fp_res (lfp_sugar_of (hd fpT_names0)));
+    val fpT_names = map (fst o dest_Type) fpTs0;
 
-    val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;
+    val _ = eq_set (op =) (fpT_names0, fpT_names) orelse not_mutually_recursive fpT_names0;
 
     val (As_names, _) = lthy |> Variable.variant_fixes (map (fn TVar ((s, _), _) => s) var_As);
     val As = map2 (fn s => fn TVar (_, S) => TFree (s, S)) As_names var_As;
-    val fpTs = map (fn s => Type (s, As)) fpT_names';
+    val fpTs = map (fn s => Type (s, As)) fpT_names;
 
     val nn_fp = length fpTs;
 
@@ -75,6 +74,7 @@
     fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
       (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs));
 
+    val fp_ctr_sugars = map (#ctr_sugar o lfp_sugar_of) fpT_names;
     val orig_descr = map3 mk_typ_descr (0 upto nn_fp - 1) fpTs fp_ctr_sugars;
     val all_infos = Datatype_Data.get_all thy;
     val (orig_descr' :: nested_descrs, _) =