generalized code
authorblanchet
Fri, 10 Apr 2015 14:03:18 +0200
changeset 60001 0e1b220ec4c9
parent 60000 b0816837ef4b
child 60002 50cf9e0ae818
generalized code
src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML
--- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Apr 10 12:44:41 2015 +0200
+++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Apr 10 14:03:18 2015 +0200
@@ -49,7 +49,7 @@
   val fold_rev_let_if_case: Proof.context -> (term list -> term -> 'a -> 'a) -> typ list ->
     term -> 'a -> 'a
   val massage_let_if_case: Proof.context -> (term -> bool) -> (typ list -> term -> term) ->
-    typ list -> term -> term
+    (typ list -> term -> unit) -> typ list -> term -> term
   val massage_nested_corec_call: Proof.context -> (term -> bool) ->
     (typ list -> typ -> typ -> term -> term) -> (typ list -> typ -> typ -> term -> term) ->
     typ list -> typ -> typ -> term -> term
@@ -243,11 +243,11 @@
     SOME {casex = Const (s', _), split_sels = _ :: _, ...} => SOME s'
   | _ => NONE);
 
-fun massage_let_if_case ctxt has_call massage_leaf bound_Ts t0 =
+fun massage_let_if_case ctxt has_call massage_leaf unexpected_call bound_Ts t0 =
   let
     val thy = Proof_Context.theory_of ctxt;
 
-    fun check_no_call t = if has_call t then unexpected_corec_call ctxt [t0] t else ();
+    fun check_no_call bound_Ts t = if has_call t then unexpected_call bound_Ts t else ();
 
     fun massage_abs bound_Ts 0 t = massage_rec bound_Ts t
       | massage_abs bound_Ts m (Abs (s, T, t)) = Abs (s, T, massage_abs (T :: bound_Ts) (m - 1) t)
@@ -264,7 +264,8 @@
             (dummy_branch' :: _, []) => dummy_branch'
           | (_, [branch']) => branch'
           | (_, branches') =>
-            Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches'))
+            Term.list_comb (If_const (typof (hd branches')) $ tap (check_no_call bound_Ts) obj,
+              branches'))
         | (c as Const (@{const_name case_prod}, _), arg :: args) =>
           massage_rec bound_Ts
             (unfold_splits_lets (Term.list_comb (c $ Envir.eta_long bound_Ts arg, args)))
@@ -287,7 +288,7 @@
                       val casex' = Const (c, branch_Ts' ---> map typof obj_leftovers ---> body_T');
                     in
                       Term.list_comb (casex',
-                        branches' @ tap (List.app check_no_call) obj_leftovers)
+                        branches' @ tap (List.app (check_no_call bound_Ts)) obj_leftovers)
                     end
                   else
                     massage_leaf bound_Ts t
@@ -304,6 +305,9 @@
       if Term.is_dummy_pattern t then Const (@{const_name undefined}, fastype_of t) else t)
   end;
 
+fun massage_let_if_case_corec ctxt has_call massage_leaf bound_Ts t0 =
+  massage_let_if_case ctxt has_call massage_leaf (K (unexpected_corec_call ctxt [t0])) bound_Ts t0;
+
 fun curried_type (Type (@{type_name fun}, [Type (@{type_name prod}, Ts), T])) = Ts ---> T;
 
 fun massage_nested_corec_call ctxt has_call massage_call massage_noncall bound_Ts U T t0 =
@@ -348,7 +352,7 @@
             (betapply (t, var))))
         end)
     and massage_any_call bound_Ts U T =
-      massage_let_if_case ctxt has_call (fn bound_Ts => fn t =>
+      massage_let_if_case_corec ctxt has_call (fn bound_Ts => fn t =>
         if has_call t then
           (case U of
             Type (s, Us) =>
@@ -384,8 +388,6 @@
           | _ => ill_formed_corec_call ctxt t)
         else
           massage_noncall bound_Ts U T t) bound_Ts;
-
-    val T = fastype_of1 (bound_Ts, t0);
   in
     (if has_call t0 then massage_any_call else massage_noncall) bound_Ts U T t0
   end;
@@ -399,12 +401,12 @@
 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
   (case fastype_of1 (bound_Ts, t) of
     Type (s, Ts) =>
-    massage_let_if_case ctxt has_call (fn _ => fn t =>
+    massage_let_if_case_corec ctxt has_call (fn _ => fn t =>
       if can (dest_ctr ctxt s) t then t else expand_to_ctr_term ctxt s Ts t) bound_Ts t
   | _ => raise Fail "expand_corec_code_rhs");
 
 fun massage_corec_code_rhs ctxt massage_ctr =
-  massage_let_if_case ctxt (K false)
+  massage_let_if_case_corec ctxt (K false)
     (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
 
 fun fold_rev_corec_code_rhs ctxt f =
@@ -883,7 +885,7 @@
       fun rewrite_end _ t = if has_call t then undef_const else t;
       fun rewrite_cont bound_Ts t =
         if has_call t then mk_tuple1_balanced bound_Ts (snd (strip_comb t)) else undef_const;
-      fun massage f _ = massage_let_if_case ctxt has_call f bound_Ts rhs_term
+      fun massage f _ = massage_let_if_case_corec ctxt has_call f bound_Ts rhs_term
         |> abs_tuple_balanced fun_args;
     in
       (massage rewrite_stop, massage rewrite_end, massage rewrite_cont)
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Fri Apr 10 12:44:41 2015 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML	Fri Apr 10 14:03:18 2015 +0200
@@ -74,12 +74,12 @@
   error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
 fun invalid_map ctxt t =
   error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
-fun unexpected_rec_call ctxt t =
-  error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t));
+fun unexpected_rec_call ctxt eqns t =
+  error_at ctxt eqns ("Unexpected recursive call in " ^ quote (Syntax.string_of_term ctxt t));
 
-fun massage_nested_rec_call ctxt has_call massage_fun bound_Ts y y' =
+fun massage_nested_rec_call ctxt has_call massage_fun bound_Ts y y' t0 =
   let
-    fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else ();
+    fun check_no_call t = if has_call t then unexpected_rec_call ctxt [t0] t else ();
 
     val typof = curry fastype_of1 bound_Ts;
     val build_map_fst = build_map ctxt [] (fst_const o fst);
@@ -95,12 +95,7 @@
         Const (@{const_name comp}, _) $ t1 $ t2 =>
         mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
       | _ =>
-        if has_call t then
-          (case try HOLogic.dest_prodT U of
-            SOME (U1, U2) => if U1 = T then massage_fun T U2 t else invalid_map ctxt t
-          | NONE => invalid_map ctxt t)
-        else
-          mk_comp bound_Ts (t, build_map_fst (U, T)));
+        if has_call t then massage_fun U T t else mk_comp bound_Ts (t, build_map_fst (U, T)));
 
     fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
         (case try (dest_map ctxt s) t of
@@ -121,7 +116,7 @@
         massage_map U T t
         handle NO_MAP _ => massage_mutual_fun U T t;
 
-    fun massage_call (t as t1 $ t2) =
+    fun massage_outer_call (t as t1 $ t2) =
         if has_call t then
           if t2 = y then
             massage_map yU yT (elim_y t1) $ y'
@@ -129,25 +124,28 @@
           else
             let val (g, xs) = Term.strip_comb t2 in
               if g = y then
-                if exists has_call xs then unexpected_rec_call ctxt t2
-                else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
+                if exists has_call xs then unexpected_rec_call ctxt [t0] t2
+                else Term.list_comb (massage_outer_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
               else
                 ill_formed_rec_call ctxt t
             end
         else
           elim_y t
-      | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
+      | massage_outer_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
   in
-    massage_call
+    massage_outer_call t0
   end;
 
-fun rewrite_map_arg ctxt get_ctr_pos rec_type res_type =
+fun rewrite_map_fun ctxt get_ctr_pos U T t =
   let
-    val pT = HOLogic.mk_prodT (rec_type, res_type);
+    val _ =
+      (case try HOLogic.dest_prodT U of
+        SOME (U1, _) => U1 = T orelse invalid_map ctxt t
+      | NONE => invalid_map ctxt t);
 
-    fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
+    fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const U)
       | subst d (Abs (v, T, b)) =
-        Abs (v, if d = SOME ~1 then pT else T, subst (Option.map (Integer.add 1) d) b)
+        Abs (v, if d = SOME ~1 then U else T, subst (Option.map (Integer.add 1) d) b)
       | subst d t =
         let
           val (u, vs) = strip_comb t;
@@ -155,22 +153,22 @@
         in
           if ctr_pos >= 0 then
             if d = SOME ~1 andalso length vs = ctr_pos then
-              Term.list_comb (permute_args ctr_pos (snd_const pT), vs)
+              Term.list_comb (permute_args ctr_pos (snd_const U), vs)
             else if length vs > ctr_pos andalso is_some d andalso
                 d = try (fn Bound n => n) (nth vs ctr_pos) then
-              Term.list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
+              Term.list_comb (snd_const U $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
             else
               error ("Recursive call not directly applied to constructor argument in " ^
                 quote (Syntax.string_of_term ctxt t))
           else
-            Term.list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
-        end
+            Term.list_comb (u, map (subst (if d = SOME ~1 then NONE else d)) vs)
+        end;
   in
-    subst (SOME ~1)
+    subst (SOME ~1) t
   end;
 
 fun rewrite_nested_rec_call ctxt has_call get_ctr_pos =
-  massage_nested_rec_call ctxt has_call (rewrite_map_arg ctxt get_ctr_pos);
+  massage_nested_rec_call ctxt has_call (rewrite_map_fun ctxt get_ctr_pos);
 
 val _ = Theory.setup (register_lfp_rec_extension
   {nested_simps = nested_simps, is_new_datatype = is_new_datatype,