--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sun Sep 01 10:45:54 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sun Sep 01 14:00:05 2013 +0200
@@ -28,11 +28,15 @@
fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
+fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
+
val free_name = try (fn Free (v, _) => v);
val const_name = try (fn Const (v, _) => v);
+val undef_const = Const (@{const_name undefined}, dummyT);
-fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
-fun abs_tuple t = if const_name t = SOME @{const_name undefined} then t else
+fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
+ |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
+fun abs_tuple t = if t = undef_const then t else
strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
val simp_attrs = @{attributes [simp]};
@@ -53,9 +57,6 @@
user_eqn: term
};
-fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
- |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
-
fun dissect_eqn lthy fun_names eqn' =
let
val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
@@ -106,20 +107,21 @@
user_eqn = eqn'}
end;
-fun rewrite_map_arg funs_data get_indices rec_type res_type =
+fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type =
let
- val fun_data = hd (the (find_first (equal rec_type o #rec_type o hd) funs_data));
- val fun_name = #fun_name fun_data;
- val ctr_pos = length (#left_args fun_data);
-
val pT = HOLogic.mk_prodT (rec_type, res_type);
val maybe_suc = Option.map (fn x => x + 1);
fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
| subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
| subst d t =
- let val (u, vs) = strip_comb t in
- if free_name u = SOME fun_name then
+ let
+ val (u, vs) = strip_comb t;
+ val maybe_fun_name_ctr_pos =
+ find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list;
+ val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos;
+ in
+ if is_some maybe_fun_name_ctr_pos then
if d = SOME ~1 andalso length vs = ctr_pos then
list_comb (permute_args ctr_pos (snd_const pT), vs)
else if length vs > ctr_pos andalso is_some d
@@ -136,42 +138,40 @@
subst (SOME ~1)
end;
-(* FIXME get rid of funs_data or get_indices *)
-fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t =
+fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t =
let
- val contains_fun = not o null o get_indices;
fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
- | subst bound_Ts (t as g $ y) =
+ | subst bound_Ts (t as g' $ y) =
let
- val is_ctr_arg = exists (exists (exists (equal y) o #ctr_args)) funs_data;
val maybe_direct_y' = AList.lookup (op =) direct_calls y;
val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
- val (g_head, g_args) = strip_comb g;
+ val (g, g_args) = strip_comb g';
+ val maybe_ctr_pos =
+ try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list;
+ val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse
+ primrec_error_eqn "too few arguments in recursive call" t;
in
- if not is_ctr_arg then
- pairself (subst bound_Ts) (g, y) |> (op $)
- else if contains_fun g_head then
- (length g_args >= the (funs_data |> get_first (fn {fun_name, left_args, ...} :: _ =>
- if fst (dest_Free g_head) = fun_name then SOME (length left_args) else NONE)) (*###*)
- orelse primrec_error_eqn "too few arguments in recursive call" t;
- list_comb (the maybe_direct_y', g_args))
+ if not (member (op =) ctr_args y) then
+ pairself (subst bound_Ts) (g', y) |> (op $)
+ else if is_some maybe_ctr_pos then
+ list_comb (the maybe_direct_y', g_args)
else if is_some maybe_indirect_y' then
- (if contains_fun g then t else y)
- |> massage_indirect_rec_call lthy contains_fun
- (rewrite_map_arg funs_data get_indices) bound_Ts y (the maybe_indirect_y')
- |> (if contains_fun g then I else curry (op $) g)
+ (if has_call g' then t else y)
+ |> massage_indirect_rec_call lthy has_call
+ (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y')
+ |> (if has_call g' then I else curry (op $) g')
else
t
end
| subst _ t = t
in
subst [] t
- |> (fn u => ((contains_fun u andalso (* FIXME detect this case earlier *)
- primrec_error_eqn "recursive call not directly applied to constructor argument" t); u))
+ |> tap (fn u => has_call u andalso (* FIXME detect this case earlier *)
+ primrec_error_eqn "recursive call not directly applied to constructor argument" t)
end;
-fun build_rec_arg lthy get_indices funs_data ctr_spec maybe_eqn_data =
- if is_none maybe_eqn_data then Const (@{const_name undefined}, dummyT) else
+fun build_rec_arg lthy funs_data has_call ctr_spec maybe_eqn_data =
+ if is_none maybe_eqn_data then undef_const else
let
val eqn_data = the maybe_eqn_data;
val t = #rhs_term eqn_data;
@@ -215,13 +215,15 @@
val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
+ val fun_name_ctr_pos_list =
+ map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
in
t
- |> subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls
+ |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls
|> fold_rev absfree abstractions
end;
-fun build_defs lthy bs mxs funs_data rec_specs get_indices =
+fun build_defs lthy bs mxs funs_data rec_specs has_call =
let
val n_funs = length funs_data;
@@ -239,7 +241,7 @@
val recs = take n_funs rec_specs |> map #recx;
val rec_args = ctr_spec_eqn_data_list
|> sort ((op <) o pairself (#offset o fst) |> make_ord)
- |> map (uncurry (build_rec_arg lthy get_indices funs_data) o apsnd (try the_single));
+ |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
val ctr_poss = map (fn x =>
if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
primrec_error ("inconstant constructor pattern position for function " ^
@@ -253,7 +255,7 @@
|> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
end;
-fun find_rec_calls get_indices eqn_data =
+fun find_rec_calls has_call eqn_data =
let
fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg
| find (t as _ $ _) ctr_arg =
@@ -265,7 +267,7 @@
find f' ctr_arg @ maps (fn x => find x ctr_arg) args'
else
let val (f, args) = chop n args' |>> curry list_comb f' in
- if exists_subterm (not o null o get_indices) f then
+ if has_call f then
f :: maps (fn x => find x ctr_arg) args
else
find f ctr_arg @ maps (fn x => find x ctr_arg) args
@@ -288,16 +290,16 @@
|> map (fn (x, y) => the_single y handle List.Empty =>
primrec_error ("missing equations for function " ^ quote x));
- fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
- |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
- |> map_filter I;
-
+ val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
val arg_Ts = map (#rec_type o hd) funs_data;
val res_Ts = map (#res_type o hd) funs_data;
val callssss = funs_data
|> map (partition_eq ((op =) o pairself #ctr))
- |> map (maps (map_filter (find_rec_calls get_indices)));
+ |> map (maps (map_filter (find_rec_calls has_call)));
+ fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
+ |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
+ |> map_filter I;
val ((nontriv, rec_specs, _, induct_thm, induct_thms), lthy') =
rec_specs_of bs arg_Ts res_Ts get_indices callssss lthy;
@@ -308,7 +310,7 @@
primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
" is not a constructor in left-hand side") user_eqn) eqns_data end;
- val defs = build_defs lthy' bs mxs funs_data rec_specs get_indices;
+ val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
fun prove def_thms' {ctr_specs, nested_map_idents, nested_map_comps, ...} induct_thm fun_data
lthy =
@@ -561,7 +563,7 @@
0 upto length ctr_specs - 1
|> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
|> Option.map #cond
- |> the_default (Const (@{const_name undefined}, dummyT)))
+ |> the_default undef_const)
|> fst o split_last;
in
(* FIXME: deal with #preds above *)
@@ -633,7 +635,7 @@
val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list;
val n_args = fold (curry (op +)) (map (K 1) (maps (map_filter #pred) ctr_specss) @
map (fn Direct_Corec _ => 3 | _ => 1) (maps (maps #calls) ctr_specss)) 0;
- val corec_args = replicate n_args (Const (@{const_name undefined}, dummyT))
+ val corec_args = replicate n_args undef_const
|> fold2 build_corec_args_discs disc_eqnss ctr_specss
|> fold2 (fn sel_eqns => fold (build_corec_args_sel sel_eqns)) sel_eqnss ctr_specss;