--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sat Aug 31 23:49:36 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sat Aug 31 23:55:03 2013 +0200
@@ -103,32 +103,6 @@
user_eqn = eqn'}
end;
-(* substitutes (f ls x rs) by (y ls rs) for all f: get_idx f \<ge> 0, (x,y) \<in> substs *)
-fun subst_direct_calls get_idx get_ctr_pos substs =
- let
- fun subst (Abs (v, T, b)) = Abs (v, T, subst b)
- | subst t =
- let
- val (f, args) = strip_comb t;
- val idx = get_idx f;
- val ctr_pos = if idx >= 0 then get_ctr_pos idx else ~1;
- in
- if idx < 0 then
- list_comb (f, map subst args)
- else if ctr_pos >= length args then
- primrec_error_eqn "too few arguments in recursive call" t
- else
- let
- val (key, repl) = the (find_first (equal (nth args ctr_pos) o fst) substs)
- handle Option.Option => primrec_error_eqn
- "recursive call not directly applied to constructor argument" t;
- in
- remove (op =) key args |> map subst |> curry list_comb repl
- end
- end
- in subst end;
-
-(* FIXME get rid of funs_data or get_indices *)
fun rewrite_map_arg funs_data get_indices y rec_type res_type =
let
val pT = HOLogic.mk_prodT (rec_type, res_type);
@@ -169,36 +143,41 @@
end;
(* FIXME get rid of funs_data or get_indices *)
-fun subst_indirect_call lthy funs_data get_indices (y, y') =
+fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t =
let
- fun massage massage_map_arg bound_Ts =
- massage_indirect_rec_call lthy (not o null o get_indices) massage_map_arg bound_Ts y y';
- fun subst bound_Ts (t as _ $ _) =
+ val contains_fun = not o null o get_indices;
+ fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
+ | subst bound_Ts (t as g $ y) =
let
- val ctr_args = fold_aterms (curry (op @) o get_indices) t []
- |> maps (maps #ctr_args o nth funs_data);
- val (f', args') = strip_comb t;
- val fun_arg_idx = find_index (exists_subterm (not o null o get_indices)) args';
- val arg_idx = find_index (exists_subterm (equal y)) args';
- val (f, args) = chop (arg_idx + 1) args' |>> curry list_comb f';
- val _ = fun_arg_idx < 0 orelse arg_idx >= 0 orelse
- exists (exists_subterm (member (op =) ctr_args)) args' orelse
- primrec_error_eqn "recursive call not applied to constructor argument" t;
+ val is_ctr_arg = exists (exists (exists (equal y) o #ctr_args)) funs_data;
+ val maybe_direct_y' = AList.lookup (op =) direct_calls y;
+ val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
+ val (g_head, g_args) = strip_comb g;
in
- if fun_arg_idx <> arg_idx andalso fun_arg_idx >= 0 andalso arg_idx >= 0 then
- if nth args' arg_idx = y then
- list_comb (massage (rewrite_map_arg funs_data get_indices y) bound_Ts f, args)
- else
- primrec_error_eqn "recursive call not directly applied to constructor argument" f
+ if not is_ctr_arg then
+ pairself (subst bound_Ts) (g, y) |> (op $)
+ else if contains_fun g_head then
+ (length g_args >= the (funs_data |> get_first (fn {fun_name, left_args, ...} :: _ =>
+ if fst (dest_Free g_head) = fun_name then SOME (length left_args) else NONE)) (*###*)
+ orelse primrec_error_eqn "too few arguments in recursive call" t;
+ list_comb (the maybe_direct_y', g_args))
+ else if is_some maybe_indirect_y' then
+ (if contains_fun g then t else y)
+ |> massage_indirect_rec_call lthy contains_fun
+ (rewrite_map_arg funs_data get_indices y) bound_Ts y (the maybe_indirect_y')
+ |> (if contains_fun g then I else curry (op $) g)
else
- list_comb (f', map (subst bound_Ts) args')
+ t
end
- | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
- | subst bound_Ts t = t |> t = y ? massage (K I |> K) bound_Ts;
- in subst [] end;
+ | subst _ t = t
+ in
+ subst [] t
+ |> (fn u => ((contains_fun u andalso (* FIXME detect this case earlier *)
+ primrec_error_eqn "recursive call not directly applied to constructor argument" t); u))
+ end;
fun build_rec_arg lthy get_indices funs_data ctr_spec maybe_eqn_data =
- if is_some maybe_eqn_data then
+ if is_none maybe_eqn_data then Const (@{const_name undefined}, dummyT) else
let
val eqn_data = the maybe_eqn_data;
val t = #rhs_term eqn_data;
@@ -241,17 +220,12 @@
val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
- val get_idx = (fn Free (v, _) => find_index (equal v o #fun_name o hd) funs_data | _ => ~1);
-
- val t' = t
- |> fold (subst_indirect_call lthy funs_data get_indices) indirect_calls
- |> subst_direct_calls get_idx (length o #left_args o hd o nth funs_data) direct_calls;
-
val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
in
- t' |> fold_rev absfree abstractions
- end
- else Const (@{const_name undefined}, dummyT)
+ t
+ |> subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls
+ |> fold_rev absfree abstractions
+ end;
fun build_defs lthy bs mxs funs_data rec_specs get_indices =
let