src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML
changeset 60001 0e1b220ec4c9
parent 59989 7b80ddb65e3e
child 60003 ba8fa0c38d66
--- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Apr 10 12:44:41 2015 +0200
+++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Apr 10 14:03:18 2015 +0200
@@ -49,7 +49,7 @@
   val fold_rev_let_if_case: Proof.context -> (term list -> term -> 'a -> 'a) -> typ list ->
     term -> 'a -> 'a
   val massage_let_if_case: Proof.context -> (term -> bool) -> (typ list -> term -> term) ->
-    typ list -> term -> term
+    (typ list -> term -> unit) -> typ list -> term -> term
   val massage_nested_corec_call: Proof.context -> (term -> bool) ->
     (typ list -> typ -> typ -> term -> term) -> (typ list -> typ -> typ -> term -> term) ->
     typ list -> typ -> typ -> term -> term
@@ -243,11 +243,11 @@
     SOME {casex = Const (s', _), split_sels = _ :: _, ...} => SOME s'
   | _ => NONE);
 
-fun massage_let_if_case ctxt has_call massage_leaf bound_Ts t0 =
+fun massage_let_if_case ctxt has_call massage_leaf unexpected_call bound_Ts t0 =
   let
     val thy = Proof_Context.theory_of ctxt;
 
-    fun check_no_call t = if has_call t then unexpected_corec_call ctxt [t0] t else ();
+    fun check_no_call bound_Ts t = if has_call t then unexpected_call bound_Ts t else ();
 
     fun massage_abs bound_Ts 0 t = massage_rec bound_Ts t
       | massage_abs bound_Ts m (Abs (s, T, t)) = Abs (s, T, massage_abs (T :: bound_Ts) (m - 1) t)
@@ -264,7 +264,8 @@
             (dummy_branch' :: _, []) => dummy_branch'
           | (_, [branch']) => branch'
           | (_, branches') =>
-            Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches'))
+            Term.list_comb (If_const (typof (hd branches')) $ tap (check_no_call bound_Ts) obj,
+              branches'))
         | (c as Const (@{const_name case_prod}, _), arg :: args) =>
           massage_rec bound_Ts
             (unfold_splits_lets (Term.list_comb (c $ Envir.eta_long bound_Ts arg, args)))
@@ -287,7 +288,7 @@
                       val casex' = Const (c, branch_Ts' ---> map typof obj_leftovers ---> body_T');
                     in
                       Term.list_comb (casex',
-                        branches' @ tap (List.app check_no_call) obj_leftovers)
+                        branches' @ tap (List.app (check_no_call bound_Ts)) obj_leftovers)
                     end
                   else
                     massage_leaf bound_Ts t
@@ -304,6 +305,9 @@
       if Term.is_dummy_pattern t then Const (@{const_name undefined}, fastype_of t) else t)
   end;
 
+fun massage_let_if_case_corec ctxt has_call massage_leaf bound_Ts t0 =
+  massage_let_if_case ctxt has_call massage_leaf (K (unexpected_corec_call ctxt [t0])) bound_Ts t0;
+
 fun curried_type (Type (@{type_name fun}, [Type (@{type_name prod}, Ts), T])) = Ts ---> T;
 
 fun massage_nested_corec_call ctxt has_call massage_call massage_noncall bound_Ts U T t0 =
@@ -348,7 +352,7 @@
             (betapply (t, var))))
         end)
     and massage_any_call bound_Ts U T =
-      massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
+      massage_let_if_case_corec ctxt has_call (fn bound_Ts => fn t =>
         if has_call t then
           (case U of
             Type (s, Us) =>
@@ -384,8 +388,6 @@
           | _ => ill_formed_corec_call ctxt t)
         else
           massage_noncall bound_Ts U T t) bound_Ts;
-
-    val T = fastype_of1 (bound_Ts, t0);
   in
     (if has_call t0 then massage_any_call else massage_noncall) bound_Ts U T t0
   end;
@@ -399,12 +401,12 @@
 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
   (case fastype_of1 (bound_Ts, t) of
     Type (s, Ts) =>
-    massage_let_if_case ctxt has_call (fn _ => fn t =>
+    massage_let_if_case_corec ctxt has_call (fn _ => fn t =>
       if can (dest_ctr ctxt s) t then t else expand_to_ctr_term ctxt s Ts t) bound_Ts t
   | _ => raise Fail "expand_corec_code_rhs");
 
 fun massage_corec_code_rhs ctxt massage_ctr =
-  massage_let_if_case ctxt (K false)
+  massage_let_if_case_corec ctxt (K false)
     (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
 
 fun fold_rev_corec_code_rhs ctxt f =
@@ -883,7 +885,7 @@
       fun rewrite_end _ t = if has_call t then undef_const else t;
       fun rewrite_cont bound_Ts t =
         if has_call t then mk_tuple1_balanced bound_Ts (snd (strip_comb t)) else undef_const;
-      fun massage f _ = massage_let_if_case ctxt has_call f bound_Ts rhs_term
+      fun massage f _ = massage_let_if_case_corec ctxt has_call f bound_Ts rhs_term
         |> abs_tuple_balanced fun_args;
     in
       (massage rewrite_stop, massage rewrite_end, massage rewrite_cont)