--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sun Sep 01 16:38:04 2013 +1000
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sun Sep 01 10:45:54 2013 +0200
@@ -28,8 +28,11 @@
fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
+val free_name = try (fn Free (v, _) => v);
+val const_name = try (fn Const (v, _) => v);
+
fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
-fun abs_tuple t = if try (fst o dest_Const) t = SOME @{const_name undefined} then t else
+fun abs_tuple t = if const_name t = SOME @{const_name undefined} then t else
strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
val simp_attrs = @{attributes [simp]};
@@ -103,43 +106,34 @@
user_eqn = eqn'}
end;
-fun rewrite_map_arg funs_data get_indices y rec_type res_type =
+fun rewrite_map_arg funs_data get_indices rec_type res_type =
let
- val pT = HOLogic.mk_prodT (rec_type, res_type);
- val fstx = fst_const pT;
- val sndx = snd_const pT;
+ val fun_data = hd (the (find_first (equal rec_type o #rec_type o hd) funs_data));
+ val fun_name = #fun_name fun_data;
+ val ctr_pos = length (#left_args fun_data);
- val SOME ({fun_name, left_args, ...} :: _) =
- find_first (equal rec_type o #rec_type o hd) funs_data;
- val ctr_pos = length left_args;
+ val pT = HOLogic.mk_prodT (rec_type, res_type);
- fun subst _ d (t as Bound d') = t |> d = d' ? curry (op $) fstx
- | subst l d (Abs (v, T, b)) = Abs (v, if d < 0 then pT else T, subst l (d + 1) b)
- | subst l d t =
+ val maybe_suc = Option.map (fn x => x + 1);
+ fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
+ | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
+ | subst d t =
let val (u, vs) = strip_comb t in
- if try (fst o dest_Free) u = SOME fun_name then
- if l andalso length vs = ctr_pos then
- list_comb (sndx |> permute_args ctr_pos, vs)
- else if length vs <= ctr_pos then
- primrec_error_eqn "too few arguments in recursive call" t
- else if nth vs ctr_pos |> member (op =) [y, Bound d] then
- list_comb (sndx $ nth vs ctr_pos, nth_drop ctr_pos vs |> map (subst false d))
+ if free_name u = SOME fun_name then
+ if d = SOME ~1 andalso length vs = ctr_pos then
+ list_comb (permute_args ctr_pos (snd_const pT), vs)
+ else if length vs > ctr_pos andalso is_some d
+ andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
+ list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
else
- primrec_error_eqn "recursive call not directly applied to constructor argument" t
- else if try (fst o dest_Const) u = SOME @{const_name comp} then
- (hd vs |> get_indices |> null orelse
- primrec_error_eqn "recursive call not directly applied to constructor argument" t;
- list_comb
- (u |> map_types (strip_type #>> (fn Ts => Ts
- |> nth_map (length Ts - 1) (K pT)
- |> nth_map (length Ts - 2) (strip_type #>> nth_map 0 (K pT) #> (op --->)))
- #> (op --->)),
- nth_map 1 (subst l d) vs))
+ primrec_error_eqn ("recursive call not directly applied to constructor argument") t
+ else if d = SOME ~1 andalso const_name u = SOME @{const_name comp} then
+ list_comb (map_types (K dummyT) u, map2 subst [NONE, d] vs)
else
- list_comb (u, map (subst false d) vs)
+ list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
end
in
- subst true ~1
+ subst (SOME ~1)
end;
(* FIXME get rid of funs_data or get_indices *)
@@ -164,7 +158,7 @@
else if is_some maybe_indirect_y' then
(if contains_fun g then t else y)
|> massage_indirect_rec_call lthy contains_fun
- (rewrite_map_arg funs_data get_indices y) bound_Ts y (the maybe_indirect_y')
+ (rewrite_map_arg funs_data get_indices) bound_Ts y (the maybe_indirect_y')
|> (if contains_fun g then I else curry (op $) g)
else
t
@@ -426,7 +420,7 @@
val fun_args = if is_none disc
then imp_rhs |> perhaps (try HOLogic.dest_not) |> HOLogic.dest_eq |> fst |> strip_comb |> snd
else the disc |> the_single o snd o strip_comb
- |> (fn t => if try (fst o dest_Free o head_of) t = SOME fun_name
+ |> (fn t => if free_name (head_of t) = SOME fun_name
then snd (strip_comb t) else []);
val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};