set stage for more flexible 'primrec' syntax for recursion through functions
authorblanchet
Fri, 18 Oct 2013 18:58:46 +0200
changeset 54159 eb5d58c99049
parent 54158 0af35cebe8ca
child 54161 496f9af15b39
set stage for more flexible 'primrec' syntax for recursion through functions
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
src/HOL/BNF/Tools/bnf_fp_util.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Oct 18 17:47:25 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Fri Oct 18 18:58:46 2013 +0200
@@ -232,7 +232,6 @@
   | flat_corec_preds_predsss_gettersss (p :: ps) (qss :: qsss) (fss :: fsss) =
     p :: flat_corec_predss_getterss qss fss @ flat_corec_preds_predsss_gettersss ps qsss fsss;
 
-fun mk_tupled_fun x f xs = HOLogic.tupled_lambda x (Term.list_comb (f, xs));
 fun mk_uncurried2_fun f xss =
   mk_tupled_fun (HOLogic.mk_tuple (map HOLogic.mk_tuple xss)) f (flat_rec_arg_args xss);
 
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 18 17:47:25 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 18 18:58:46 2013 +0200
@@ -172,27 +172,30 @@
   let
     fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
       | subst bound_Ts (t as g' $ y) =
-        if not (member (op =) ctr_args y) then
-          pairself (subst bound_Ts) (g', y) |> op $
-        else
-          let
-            val maybe_mutual_y' = AList.lookup (op =) mutual_calls y;
-            val maybe_nested_y' = AList.lookup (op =) nested_calls y;
-            val (g, g_args) = strip_comb g';
-            val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
-            val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
-              primrec_error_eqn "too few arguments in recursive call" t;
-          in
-            if ctr_pos >= 0 then
-              list_comb (the maybe_mutual_y', g_args)
-            else if is_some maybe_nested_y' then
-              (if has_call g' then t else y)
-              |> massage_nested_rec_call lthy has_call
-                (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_nested_y')
-              |> (if has_call g' then I else curry (op $) g')
-            else
-              t
-          end
+        let val y_head = head_of y in
+          if not (member (op =) ctr_args y_head) then
+            pairself (subst bound_Ts) (g', y) |> op $
+          else
+            let
+              val maybe_mutual_y' = AList.lookup (op =) mutual_calls y;
+              val maybe_nested_y_head' = AList.lookup (op =) nested_calls y_head;
+              val (g, g_args) = strip_comb g';
+              val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
+              val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
+                primrec_error_eqn "too few arguments in recursive call" t;
+            in
+              if ctr_pos >= 0 then
+                list_comb (the maybe_mutual_y', g_args)
+              else if is_some maybe_nested_y_head' then
+                (if has_call g' then t else y)
+                |> massage_nested_rec_call lthy has_call
+                  (rewrite_map_arg get_ctr_pos) bound_Ts y_head (the maybe_nested_y_head')
+                |> (if has_call g' then I else curry (op $) g')
+              else
+                t
+            end
+            |> tap (fn t => tracing ("*** " ^ Syntax.string_of_term lthy t)) (*###*)
+        end
       | subst _ t = t
   in
     subst [] t
@@ -582,7 +585,7 @@
       else apfst SOME (dissect_coeqn_disc seq fun_names ctr_specss
           (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss);
 
-    val sel_concls = (sels ~~ ctr_args)
+    val sel_concls = sels ~~ ctr_args
       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
 
 (*
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Fri Oct 18 17:47:25 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Fri Oct 18 18:58:46 2013 +0200
@@ -215,6 +215,19 @@
     SOME (U1, U2) => if U1 = T then massage T U2 else invalid_map ctxt
   | NONE => invalid_map ctxt);
 
+fun expand_to_comp bound_Ts f t =
+  let
+    val (g, xs) = Term.strip_comb t;
+    val m = length xs;
+    val j = Term.maxidx_of_term t;
+    val us = map2 (fn k => fn x => Var ((Name.uu, j + k), fastype_of1 (bound_Ts, x))) (1 upto m) xs;
+    val u_tuple = HOLogic.mk_tuple us;
+    val unc_g = mk_tupled_fun u_tuple g us;
+    val x_tuple = HOLogic.mk_tuple xs;
+  in
+    (HOLogic.mk_comp (f, unc_g), x_tuple)
+  end;
+
 fun map_flattened_map_args ctxt s map_args fs =
   let
     val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
@@ -261,6 +274,11 @@
         if t2 = y then
           massage_map yU yT (elim_y t1) $ y'
           handle AINT_NO_MAP t' => invalid_map ctxt t'
+        else if head_of t2 = y then
+          let val (u1, u2) = expand_to_comp bound_Ts t1 t2 in
+            if has_call u2 then unexpected_rec_call ctxt u2
+            else massage_call u1 $ u2
+          end
         else
           ill_formed_rec_call ctxt t
       | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
@@ -446,17 +464,17 @@
     if has_call t then massage_call bound_Ts U T t else build_map_Inl (T, U) $ t
   end;
 
-fun expand_ctr_term ctxt s Ts t =
+fun expand_to_ctr_term ctxt s Ts t =
   (case ctr_sugar_of ctxt s of
     SOME {ctrs, casex, ...} =>
     Term.list_comb (mk_case Ts (Type (s, Ts)) casex, map (mk_ctr Ts) ctrs) $ t
-  | NONE => raise Fail "expand_ctr_term");
+  | NONE => raise Fail "expand_to_ctr_term");
 
 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
   (case fastype_of1 (bound_Ts, t) of
     Type (s, Ts) =>
     massage_let_if_case ctxt has_call (fn _ => fn t =>
-      if can (dest_ctr ctxt s) t then t else expand_ctr_term ctxt s Ts t) bound_Ts t
+      if can (dest_ctr ctxt s) t then t else expand_to_ctr_term ctxt s Ts t) bound_Ts t
   | _ => raise Fail "expand_corec_code_rhs");
 
 fun massage_corec_code_rhs ctxt massage_ctr =
--- a/src/HOL/BNF/Tools/bnf_fp_util.ML	Fri Oct 18 17:47:25 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_util.ML	Fri Oct 18 18:58:46 2013 +0200
@@ -139,6 +139,7 @@
   val mk_sumTN_balanced: typ list -> typ
 
   val mk_convol: term * term -> term
+  val mk_tupled_fun: term -> term -> term list -> term
 
   val Inl_const: typ -> typ -> term
   val Inr_const: typ -> typ -> term
@@ -381,6 +382,9 @@
     val convolT = fTU --> gTU --> gT --> HOLogic.mk_prodT (fU, gU);
   in Const (@{const_name convol}, convolT) $ f $ g end;
 
+fun mk_tupled_fun x f xs =
+  if xs = [x] then f else HOLogic.tupled_lambda x (Term.list_comb (f, xs));
+
 fun Inl_const LT RT = Const (@{const_name Inl}, LT --> mk_sumT (LT, RT));
 fun mk_Inl RT t = Inl_const (fastype_of t) RT $ t;