src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML
changeset 60002 50cf9e0ae818
parent 60001 0e1b220ec4c9
child 61760 1647bb489522
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Fri Apr 10 14:03:18 2015 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Fri Apr 10 14:44:08 2015 +0200
@@ -9,7 +9,7 @@
 signature BNF_LFP_REC_SUGAR_MORE =
 sig
   val massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
-    typ list -> term -> term -> term -> term
+    (typ * typ -> term) -> typ list -> term -> term -> term -> term
 end;
 
 structure BNF_LFP_Rec_Sugar_More : BNF_LFP_REC_SUGAR_MORE =
@@ -77,17 +77,17 @@
 fun unexpected_rec_call ctxt eqns t =
   error_at ctxt eqns ("Unexpected recursive call in " ^ quote (Syntax.string_of_term ctxt t));
 
-fun massage_nested_rec_call ctxt has_call massage_fun bound_Ts y y' t0 =
+fun massage_nested_rec_call ctxt has_call massage_fun massage_nonfun bound_Ts y y' t0 =
   let
     fun check_no_call t = if has_call t then unexpected_rec_call ctxt [t0] t else ();
 
     val typof = curry fastype_of1 bound_Ts;
-    val build_map_fst = build_map ctxt [] (fst_const o fst);
+    val massage_no_call = build_map ctxt [] massage_nonfun;
 
     val yT = typof y;
     val yU = typof y';
 
-    fun y_of_y' () = build_map_fst (yU, yT) $ y';
+    fun y_of_y' () = massage_no_call (yU, yT) $ y';
     val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);
 
     fun massage_mutual_fun U T t =
@@ -95,7 +95,7 @@
         Const (@{const_name comp}, _) $ t1 $ t2 =>
         mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
       | _ =>
-        if has_call t then massage_fun U T t else mk_comp bound_Ts (t, build_map_fst (U, T)));
+        if has_call t then massage_fun U T t else mk_comp bound_Ts (t, massage_no_call (U, T)));
 
     fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
         (case try (dest_map ctxt s) t of
@@ -168,7 +168,7 @@
   end;
 
 fun rewrite_nested_rec_call ctxt has_call get_ctr_pos =
-  massage_nested_rec_call ctxt has_call (rewrite_map_fun ctxt get_ctr_pos);
+  massage_nested_rec_call ctxt has_call (rewrite_map_fun ctxt get_ctr_pos) (fst_const o fst);
 
 val _ = Theory.setup (register_lfp_rec_extension
   {nested_simps = nested_simps, is_new_datatype = is_new_datatype,