merged
authortraytel
Sat, 31 Aug 2013 23:55:03 +0200
changeset 53354 b7469b85ca28
parent 53353 0c1c67e3fccc (current diff)
parent 53351 4335477c60f5 (diff)
child 53355 603e6e97c391
merged
src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sat Aug 31 23:49:36 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Sat Aug 31 23:55:03 2013 +0200
@@ -103,32 +103,6 @@
      user_eqn = eqn'}
   end;
 
-(* substitutes (f ls x rs) by (y ls rs) for all f: get_idx f \<ge> 0, (x,y) \<in> substs *)
-fun subst_direct_calls get_idx get_ctr_pos substs = 
-  let
-    fun subst (Abs (v, T, b)) = Abs (v, T, subst b)
-      | subst t =
-        let
-          val (f, args) = strip_comb t;
-          val idx = get_idx f;
-          val ctr_pos  = if idx >= 0 then get_ctr_pos idx else ~1;
-        in
-          if idx < 0 then
-            list_comb (f, map subst args)
-          else if ctr_pos >= length args then
-            primrec_error_eqn "too few arguments in recursive call" t
-          else
-            let
-              val (key, repl) = the (find_first (equal (nth args ctr_pos) o fst) substs)
-                handle Option.Option => primrec_error_eqn
-                  "recursive call not directly applied to constructor argument" t;
-            in
-              remove (op =) key args |> map subst |> curry list_comb repl
-            end
-        end
-  in subst end;
-
-(* FIXME get rid of funs_data or get_indices *)
 fun rewrite_map_arg funs_data get_indices y rec_type res_type =
   let
     val pT = HOLogic.mk_prodT (rec_type, res_type);
@@ -169,36 +143,41 @@
   end;
 
 (* FIXME get rid of funs_data or get_indices *)
-fun subst_indirect_call lthy funs_data get_indices (y, y') =
+fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t =
   let
-    fun massage massage_map_arg bound_Ts =
-      massage_indirect_rec_call lthy (not o null o get_indices) massage_map_arg bound_Ts y y';
-    fun subst bound_Ts (t as _ $ _) =
+    val contains_fun = not o null o get_indices;
+    fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
+      | subst bound_Ts (t as g $ y) =
         let
-          val ctr_args = fold_aterms (curry (op @) o get_indices) t []
-            |> maps (maps #ctr_args o nth funs_data);
-          val (f', args') = strip_comb t;
-          val fun_arg_idx = find_index (exists_subterm (not o null o get_indices)) args';
-          val arg_idx = find_index (exists_subterm (equal y)) args';
-          val (f, args) = chop (arg_idx + 1) args' |>> curry list_comb f';
-          val _ = fun_arg_idx < 0 orelse arg_idx >= 0 orelse
-            exists (exists_subterm (member (op =) ctr_args)) args' orelse
-            primrec_error_eqn "recursive call not applied to constructor argument" t;
+          val is_ctr_arg = exists (exists (exists (equal y) o #ctr_args)) funs_data;
+          val maybe_direct_y' = AList.lookup (op =) direct_calls y;
+          val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
+          val (g_head, g_args) = strip_comb g;
         in
-          if fun_arg_idx <> arg_idx andalso fun_arg_idx >= 0 andalso arg_idx >= 0 then
-            if nth args' arg_idx = y then
-              list_comb (massage (rewrite_map_arg funs_data get_indices y) bound_Ts f, args)
-            else
-              primrec_error_eqn "recursive call not directly applied to constructor argument" f
+          if not is_ctr_arg then
+            pairself (subst bound_Ts) (g, y) |> (op $)
+          else if contains_fun g_head then
+            (length g_args >= the (funs_data |> get_first (fn {fun_name, left_args, ...} :: _ =>
+              if fst (dest_Free g_head) = fun_name then SOME (length left_args) else NONE)) (*###*)
+                orelse primrec_error_eqn "too few arguments in recursive call" t;
+            list_comb (the maybe_direct_y', g_args))
+          else if is_some maybe_indirect_y' then
+            (if contains_fun g then t else y)
+            |> massage_indirect_rec_call lthy contains_fun
+              (rewrite_map_arg funs_data get_indices y) bound_Ts y (the maybe_indirect_y')
+            |> (if contains_fun g then I else curry (op $) g)
           else
-            list_comb (f', map (subst bound_Ts) args')
+            t
         end
-      | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
-      | subst bound_Ts t = t |> t = y ? massage (K I |> K) bound_Ts;
-  in subst [] end;
+      | subst _ t = t
+  in
+    subst [] t
+    |> (fn u => ((contains_fun u andalso (* FIXME detect this case earlier *)
+      primrec_error_eqn "recursive call not directly applied to constructor argument" t); u))
+  end;
 
 fun build_rec_arg lthy get_indices funs_data ctr_spec maybe_eqn_data =
-  if is_some maybe_eqn_data then
+  if is_none maybe_eqn_data then Const (@{const_name undefined}, dummyT) else
     let
       val eqn_data = the maybe_eqn_data;
       val t = #rhs_term eqn_data;
@@ -241,17 +220,12 @@
       val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
       val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
 
-      val get_idx = (fn Free (v, _) => find_index (equal v o #fun_name o hd) funs_data | _ => ~1);
-
-      val t' = t
-        |> fold (subst_indirect_call lthy funs_data get_indices) indirect_calls
-        |> subst_direct_calls get_idx (length o #left_args o hd o nth funs_data) direct_calls;
-
       val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
     in
-      t' |> fold_rev absfree abstractions
-    end
-  else Const (@{const_name undefined}, dummyT)
+      t
+      |> subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls
+      |> fold_rev absfree abstractions
+    end;
 
 fun build_defs lthy bs mxs funs_data rec_specs get_indices =
   let