better handling of recursion through functions
authorblanchet
Fri, 14 Feb 2014 17:18:28 +0100
changeset 55485 bdfb607543f4
parent 55484 9deb5066508f
child 55486 8609527278f2
better handling of recursion through functions
src/HOL/Tools/BNF/bnf_fp_rec_sugar_util.ML
src/HOL/Tools/BNF/bnf_fp_util.ML
src/HOL/Tools/BNF/bnf_lfp_compat.ML
--- a/src/HOL/Tools/BNF/bnf_fp_rec_sugar_util.ML	Fri Feb 14 16:22:09 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_rec_sugar_util.ML	Fri Feb 14 17:18:28 2014 +0100
@@ -17,11 +17,6 @@
 
   val drop_all: term -> term
 
-  val mk_partial_compN: int -> typ -> term -> term
-  val mk_partial_comp: typ -> typ -> term -> term
-  val mk_compN: int -> typ list -> term * term -> term
-  val mk_comp: typ list -> term * term -> term
-
   val get_indices: ((binding * typ) * 'a) list -> term -> int list
 end;
 
@@ -42,24 +37,6 @@
   subst_bounds (strip_qnt_vars @{const_name all} t |> map Free |> rev,
     strip_qnt_body @{const_name all} t);
 
-fun mk_partial_comp gT fT g =
-  let val T = domain_type fT --> range_type gT in
-    Const (@{const_name Fun.comp}, gT --> fT --> T) $ g
-  end;
-
-fun mk_partial_compN 0 _ g = g
-  | mk_partial_compN n fT g =
-    let val g' = mk_partial_compN (n - 1) (range_type fT) g in
-      mk_partial_comp (fastype_of g') fT g'
-    end;
-
-fun mk_compN n bound_Ts (g, f) =
-  let val typof = curry fastype_of1 bound_Ts in
-    mk_partial_compN n (typof f) g $ f
-  end;
-
-val mk_comp = mk_compN 1;
-
 fun get_indices fixes t = map (fst #>> Binding.name_of #> Free) fixes
   |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
   |> map_filter I;
--- a/src/HOL/Tools/BNF/bnf_fp_util.ML	Fri Feb 14 16:22:09 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_util.ML	Fri Feb 14 17:18:28 2014 +0100
@@ -139,6 +139,11 @@
 
   val mk_proj: typ -> int -> int -> term
 
+  val mk_partial_compN: int -> typ -> term -> term
+  val mk_partial_comp: typ -> typ -> term -> term
+  val mk_compN: int -> typ list -> term * term -> term
+  val mk_comp: typ list -> term * term -> term
+
   val mk_convol: term * term -> term
 
   val Inl_const: typ -> typ -> term
@@ -377,6 +382,24 @@
     fold_rev (fn T => fn t => Abs (Name.uu, T, t)) binders (Bound (n - k - 1))
   end;
 
+fun mk_partial_comp gT fT g =
+  let val T = domain_type fT --> range_type gT in
+    Const (@{const_name Fun.comp}, gT --> fT --> T) $ g
+  end;
+
+fun mk_partial_compN 0 _ g = g
+  | mk_partial_compN n fT g =
+    let val g' = mk_partial_compN (n - 1) (range_type fT) g in
+      mk_partial_comp (fastype_of g') fT g'
+    end;
+
+fun mk_compN n bound_Ts (g, f) =
+  let val typof = curry fastype_of1 bound_Ts in
+    mk_partial_compN n (typof f) g $ f
+  end;
+
+val mk_comp = mk_compN 1;
+
 fun mk_convol (f, g) =
   let
     val (fU, fTU) = `range_type (fastype_of f);
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 16:22:09 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 17:18:28 2014 +0100
@@ -46,8 +46,8 @@
         SOME (fp_sugar as {fp = Least_FP, ...}) => fp_sugar
       | _ => not_datatype s);
 
-    val {ctr_sugars, ...} = lfp_sugar_of fpT_name1;
-    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) ctr_sugars;
+    val {ctr_sugars = fp_ctr_sugars, ...} = lfp_sugar_of fpT_name1;
+    val fpTs0 as Type (_, var_As) :: _ = map (body_type o fastype_of o hd o #ctrs) fp_ctr_sugars;
     val fpT_names' = map (fst o dest_Type) fpTs0;
 
     val _ = eq_set (op =) (fpT_names, fpT_names') orelse not_mutually_recursive fpT_names;
@@ -71,7 +71,7 @@
               if s = @{type_name fun} then
                 if exists_subtype_in mutual_Ts U then
                   (warning "Incomplete support for recursion through functions -- \
-                     \'primrec' will fail";
+                     \the old 'primrec' will fail";
                    Tindices_of_ctr_arg parent_Tkks (range_type U) accum)
                 else
                   ([], accum)
@@ -99,11 +99,23 @@
       | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
           " not corresponding to new-style datatype (cf. \"datatype_new\")"));
 
-    fun get_indices (Bound kk) = [kk];
+    fun get_indices (Var ((_, kk), _)) = [kk];
 
     val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0;
     val Ts = map fst Tkkssss;
-    val callssss = map (map (map (map Bound)) o snd) Tkkssss;
+
+    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
+    val ctrss0 = map (#ctrs o of_fp_sugar #ctr_sugars) fp_sugars0;
+    val ctr_Tsss0 = map (map (binder_types o fastype_of)) ctrss0;
+
+    fun apply_comps n kk =
+      mk_partial_compN n (replicate n @{typ unit} ---> @{typ unit})
+        (Var ((Name.uu, kk), @{typ "unit => unit"}));
+
+    val callssss =
+      map2 (map2 (map2 (fn kks => fn ctr_T =>
+          map (apply_comps (num_binder_types ctr_T)) kks)) o snd)
+        Tkkssss ctr_Tsss0;
 
     val b_names = Name.variant_list [] (map base_name_of_typ Ts);
     val compat_b_names = map (prefix compatN) b_names;
@@ -111,7 +123,6 @@
     val common_name = compatN ^ mk_common_name b_names;
     val nn_fp = length fpTs;
     val nn = length Ts;
-    val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
 
     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
       if nn > nn_fp then