handle applied ctor arguments gracefully when computing 'callssss' (for recursion through functions)
authorblanchet
Thu, 24 Oct 2013 19:43:21 +0200
changeset 54202 0a06b51ffa56
parent 54201 334a29265b2d
child 54203 4d3a481fc48e
handle applied ctor arguments gracefully when computing 'callssss' (for recursion through functions)
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Oct 24 18:50:59 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Oct 24 19:43:21 2013 +0200
@@ -287,27 +287,35 @@
       bs mxs
   end;
 
-fun find_rec_calls has_call (eqn_data : eqn_data) =
+fun massage_comp ctxt has_call bound_Ts t =
+  massage_nested_corec_call ctxt has_call (K (K (K I))) bound_Ts (fastype_of1 (bound_Ts, t)) t;
+
+fun find_rec_calls ctxt has_call (eqn_data : eqn_data) =
   let
-    fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg
-      | find (t as _ $ _) ctr_arg =
+    fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
+      | find bound_Ts (t as _ $ _) ctr_arg =
         let
+          val typof = curry fastype_of1 bound_Ts;
           val (f', args') = strip_comb t;
-          val n = find_index (equal ctr_arg) args';
+          val n = find_index (equal ctr_arg o head_of) args';
         in
           if n < 0 then
-            find f' ctr_arg @ maps (fn x => find x ctr_arg) args'
+            find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
           else
-            let val (f, args) = chop n args' |>> curry list_comb f' in
+            let
+              val (f, args as arg :: _) = chop n args' |>> curry list_comb f'
+              val (arg_head, arg_args) = Term.strip_comb arg;
+            in
               if has_call f then
-                f :: maps (fn x => find x ctr_arg) args
+                mk_partial_compN (length arg_args) (typof f) (typof arg_head) f ::
+                maps (fn x => find bound_Ts x ctr_arg) args
               else
-                find f ctr_arg @ maps (fn x => find x ctr_arg) args
+                find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
             end
         end
-      | find _ _ = [];
+      | find _ _ _ = [];
   in
-    map (find (#rhs_term eqn_data)) (#ctr_args eqn_data)
+    map (find [] (#rhs_term eqn_data)) (#ctr_args eqn_data)
     |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
   end;
 
@@ -327,7 +335,7 @@
     val res_Ts = map (#res_type o hd) funs_data;
     val callssss = funs_data
       |> map (partition_eq ((op =) o pairself #ctr))
-      |> map (maps (map_filter (find_rec_calls has_call)));
+      |> map (maps (map_filter (find_rec_calls lthy has_call)));
 
     val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
       rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Thu Oct 24 18:50:59 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Thu Oct 24 19:43:21 2013 +0200
@@ -62,6 +62,8 @@
   val s_disjs: term list -> term
   val s_dnf: term list list -> term list
 
+  val mk_partial_compN: int -> typ -> typ -> term -> term
+
   val massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
     typ list -> term -> term -> term -> term
   val unfold_let: term -> term
@@ -212,6 +214,22 @@
       |> (fn [cs] => cs | css => [s_disjs (map s_conjs css)])
   end;
 
+fun mk_partial_comp gT fT g =
+  let val T = domain_type fT --> range_type gT in
+    Const (@{const_name Fun.comp}, gT --> fT --> T) $ g
+  end;
+
+fun mk_partial_compN 0 _ _ g = g
+  | mk_partial_compN n gT fT g =
+    let val g' = mk_partial_compN (n - 1) gT (range_type fT) g in
+      mk_partial_comp (fastype_of g') fT g'
+    end;
+
+fun mk_compN bound_Ts n (g, f) =
+  let val typof = curry fastype_of1 bound_Ts in
+    mk_partial_compN n (typof g) (typof f) g $ f
+  end;
+
 fun factor_out_types ctxt massage destU U T =
   (case try destU U of
     SOME (U1, U2) => if U1 = T then massage T U2 else invalid_map ctxt
@@ -225,22 +243,6 @@
     permute_like (op aconv) flat_fs fs flat_fs'
   end;
 
-fun mk_partial_comp gT fT g =
-  let val T = domain_type fT --> range_type gT in
-    Const (@{const_name Fun.comp}, gT --> fT --> T) $ g
-  end;
-
-fun mk_compN' 0 _ _ g = g
-  | mk_compN' n gT fT g =
-    let val g' = mk_compN' (n - 1) gT (range_type fT) g in
-      mk_partial_comp (fastype_of g') fT g'
-    end;
-
-fun mk_compN bound_Ts n (g, f) =
-  let val typof = curry fastype_of1 bound_Ts in
-    mk_compN' n (typof g) (typof f) g $ f
-  end;
-
 fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' =
   let
     val typof = curry fastype_of1 bound_Ts;