src/HOL/Tools/BNF/bnf_lfp_compat.ML
changeset 55485 bdfb607543f4
parent 55481 a8b83356e869
child 55486 8609527278f2
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 16:22:09 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 17:18:28 2014 +0100
@@ -46,8 +46,8 @@
         SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
       | _ => not_datatype s);
 
-    val {ctr_sugars, ...} = lfp_sugar_of fpT_name1;
-    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) ctr_sugars;
+    val {ctr_sugars = fp_ctr_sugars, ...} = lfp_sugar_of fpT_name1;
+    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) fp_ctr_sugars;
     val fpT_names' = map (fst o dest_Type) fpTs0;
 
     val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;
@@ -71,7 +71,7 @@
               if s = @{type_name fun} then
                 if exists_subtype_in mutual_Ts U then
                   (warning "Incomplete support for recursion through functions -- \
-                     \'primrec' will fail";
+                     \the old 'primrec' will fail";
                    Tindices_of_ctr_arg parent_Tkks (range_type U) accum)
                 else
                   ([], accum)
@@ -99,11 +99,23 @@
       | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
           " not corresponding to new-style datatype (cf. \"datatype_new\")"));
 
-    fun get_indices (Bound kk) = [kk];
+    fun get_indices (Var ((_, kk), _)) = [kk];
 
     val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0;
     val Ts = map fst Tkkssss;
-    val callssss = map (map (map (map Bound)) o snd) Tkkssss;
+
+    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
+    val ctrss0 = map (#ctrs o of_fp_sugar #ctr_sugars) fp_sugars0;
+    val ctr_Tsss0 = map (map (binder_types o fastype_of)) ctrss0;
+
+    fun apply_comps n kk =
+      mk_partial_compN n (replicate n @{typ unit} ---> @{typ unit})
+        (Var ((Name.uu, kk), @{typ "unit => unit"}));
+
+    val callssss =
+      map2 (map2 (map2 (fn kks => fn ctr_T =>
+          map (apply_comps (num_binder_types ctr_T)) kks)) o snd)
+        Tkkssss ctr_Tsss0;
 
     val b_names = Name.variant_list [] (map base_name_of_typ Ts);
     val compat_b_names = map (prefix compatN) b_names;
@@ -111,7 +123,6 @@
     val common_name = compatN ^ mk_common_name b_names;
     val nn_fp = length fpTs;
     val nn = length Ts;
-    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
 
     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
       if nn > nn_fp then