--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Tue Sep 03 21:46:42 2013 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Wed Sep 04 02:11:50 2013 +0200
@@ -36,8 +36,7 @@
fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
|> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n);
-fun abs_tuple t = if t = undef_const then t else
- strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
+val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
val simp_attrs = @{attributes [simp]};
@@ -107,7 +106,7 @@
user_eqn = eqn'}
end;
-fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type =
+fun rewrite_map_arg get_ctr_pos rec_type res_type =
let
val pT = HOLogic.mk_prodT (rec_type, res_type);
@@ -117,11 +116,9 @@
| subst d t =
let
val (u, vs) = strip_comb t;
- val maybe_fun_name_ctr_pos =
- find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list;
- val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos;
+ val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1;
in
- if is_some maybe_fun_name_ctr_pos then
+ if ctr_pos >= 0 then
if d = SOME ~1 andalso length vs = ctr_pos then
list_comb (permute_args ctr_pos (snd_const pT), vs)
else if length vs > ctr_pos andalso is_some d
@@ -138,7 +135,7 @@
subst (SOME ~1)
end;
-fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t =
+fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t =
let
fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
| subst bound_Ts (t as g' $ y) =
@@ -146,19 +143,18 @@
val maybe_direct_y' = AList.lookup (op =) direct_calls y;
val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
val (g, g_args) = strip_comb g';
- val maybe_ctr_pos =
- try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list;
- val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse
+ val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
+ val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
primrec_error_eqn "too few arguments in recursive call" t;
in
if not (member (op =) ctr_args y) then
pairself (subst bound_Ts) (g', y) |> (op $)
- else if is_some maybe_ctr_pos then
+ else if ctr_pos >= 0 then
list_comb (the maybe_direct_y', g_args)
else if is_some maybe_indirect_y' then
(if has_call g' then t else y)
|> massage_indirect_rec_call lthy has_call
- (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y')
+ (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_indirect_y')
|> (if has_call g' then I else curry (op $) g')
else
t
@@ -211,16 +207,17 @@
nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_indirect_type)))
indirect_calls';
+ val fun_name_ctr_pos_list =
+ map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
+ val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
- val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data);
- val fun_name_ctr_pos_list =
- map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
+ val abstractions = args @ #left_args eqn_data @ #right_args eqn_data;
in
t
- |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls
- |> fold_rev absfree abstractions
+ |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls
+ |> fold_rev lambda abstractions
end;
fun build_defs lthy bs mxs funs_data rec_specs has_call =
@@ -372,15 +369,16 @@
type co_eqn_data_disc = {
fun_name: string,
+ fun_args: term list,
ctr_no: int, (*###*)
cond: term,
user_eqn: term
};
type co_eqn_data_sel = {
fun_name: string,
+ fun_args: term list,
ctr: term,
sel: term,
- fun_args: term list,
rhs_term: term,
user_eqn: term
};
@@ -388,11 +386,10 @@
Disc of co_eqn_data_disc |
Sel of co_eqn_data_sel;
-fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
+fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
let
fun find_subterm p = let (* FIXME \<exists>? *)
- fun f (t as u $ v) =
- fold_rev (curry merge_options) [if p t then SOME t else NONE, f u, f v] NONE
+ fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
| f t = if p t then SOME t else NONE
in f end;
@@ -406,9 +403,8 @@
val discs = #ctr_specs corec_spec |> map #disc;
val ctrs = #ctr_specs corec_spec |> map #ctr;
- val n_ctrs = length ctrs;
val not_disc = head_of imp_rhs = @{term Not};
- val _ = not_disc andalso n_ctrs <> 2 andalso
+ val _ = not_disc andalso length ctrs <> 2 andalso
primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" imp_rhs;
val disc = find_subterm (member (op =) discs o head_of) imp_rhs;
val eq_ctr0 = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
@@ -428,32 +424,28 @@
val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_;
- val matched_conds = filter (equal fun_name o fst) matched_conds_ps |> map snd;
- val imp_lhs = mk_conjs imp_lhs';
+ val matched_cond = filter (equal fun_name o fst) matched_conds |> map snd |> mk_disjs;
+ val imp_lhs = mk_conjs imp_lhs'
+ |> incr_boundvars (length fun_args)
+ |> subst_atomic (fun_args ~~ map Bound (length fun_args - 1 downto 0))
val cond =
if catch_all then
- if null matched_conds then fold_rev absfree (map dest_Free fun_args) @{const True} else
- (strip_abs_vars (hd matched_conds),
- mk_disjs (map strip_abs_body matched_conds) |> HOLogic.mk_not)
- |-> fold_rev (fn (v, T) => fn u => Abs (v, T, u))
+ matched_cond |> HOLogic.mk_not
else if sequential then
- HOLogic.mk_conj (HOLogic.mk_not (mk_disjs (map strip_abs_body matched_conds)), imp_lhs)
- |> fold_rev absfree (map dest_Free fun_args)
+ HOLogic.mk_conj (HOLogic.mk_not matched_cond, imp_lhs)
else
- imp_lhs |> fold_rev absfree (map dest_Free fun_args);
- val matched_cond =
- if sequential then fold_rev absfree (map dest_Free fun_args) imp_lhs else cond;
+ imp_lhs;
- val matched_conds_ps' = if catch_all
- then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps
- else (fun_name, matched_cond) :: matched_conds_ps;
+ val matched_conds' =
+ (fun_name, if catch_all orelse not sequential then cond else imp_lhs) :: matched_conds;
in
(Disc {
fun_name = fun_name,
+ fun_args = fun_args,
ctr_no = ctr_no,
cond = cond,
user_eqn = eqn'
- }, matched_conds_ps')
+ }, matched_conds')
end;
fun co_dissect_eqn_sel fun_name_corec_spec_list eqn' eqn =
@@ -473,15 +465,15 @@
in
Sel {
fun_name = fun_name,
+ fun_args = fun_args,
ctr = #ctr ctr_spec,
sel = sel,
- fun_args = fun_args,
rhs_term = rhs,
user_eqn = eqn'
}
end;
-fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
+fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
let
val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
val fun_name = head_of lhs |> fst o dest_Free;
@@ -491,10 +483,10 @@
handle Option.Option => primrec_error_eqn "not a constructor" ctr;
val disc_imp_rhs = betapply (#disc ctr_spec, lhs);
- val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1
- then (NONE, matched_conds_ps)
+ val (maybe_eqn_data_disc, matched_conds') = if length (#ctr_specs corec_spec) = 1
+ then (NONE, matched_conds)
else apfst SOME (co_dissect_eqn_disc
- sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps);
+ sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds);
val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
|> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
@@ -506,10 +498,10 @@
val eqns_data_sel =
map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss;
in
- (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps')
+ (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds')
end;
-fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps =
+fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds =
let
val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
strip_qnt_body @{const_name all} eqn')
@@ -531,65 +523,68 @@
if member (op =) discs head orelse
is_some maybe_rhs andalso
member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
- co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
+ co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
|>> single
else if member (op =) sels head then
- ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds_ps)
+ ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds)
else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
- co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps
+ co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
else
primrec_error_eqn "malformed function equation" eqn
end;
fun build_corec_args_discs disc_eqns ctr_specs =
- let
- val conds = map #cond disc_eqns;
- val args' =
- if length ctr_specs = 1 then []
- else if length disc_eqns = length ctr_specs then
- fst (split_last conds)
- else if length disc_eqns = length ctr_specs - 1 then
- let val n = 0 upto length ctr_specs - 1
- |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
- if n = length ctr_specs - 1 then
- conds
- else
- split_last conds
- ||> (fn t => fold_rev absfree (strip_abs_vars t) (strip_abs_body t |> HOLogic.mk_not))
- |>> chop n
- |> (fn ((l, r), x) => l @ (x :: r))
- end
- else
- 0 upto length ctr_specs - 1
- |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
- |> Option.map #cond
- |> the_default undef_const)
- |> fst o split_last;
- in
- (* FIXME: deal with #preds above *)
- fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args'
- end;
+ if null disc_eqns then I else
+ let
+ val conds = map #cond disc_eqns;
+ val fun_args = #fun_args (hd disc_eqns);
+ val args =
+ if length ctr_specs = 1 then []
+ else if length disc_eqns = length ctr_specs then
+ fst (split_last conds)
+ else if length disc_eqns = length ctr_specs - 1 then
+ let val n = 0 upto length ctr_specs - 1
+ |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in
+ if n = length ctr_specs - 1 then
+ conds
+ else
+ split_last conds
+ ||> HOLogic.mk_not
+ |>> chop n
+ |> (fn ((l, r), x) => l @ (x :: r))
+ end
+ else
+ 0 upto length ctr_specs - 1
+ |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns
+ |> Option.map #cond
+ |> the_default undef_const)
+ |> fst o split_last;
+ in
+ (* FIXME deal with #preds above *)
+ (map_filter #pred ctr_specs, args)
+ |-> fold2 (fn idx => fn t => nth_map idx
+ (K (subst_bounds (List.rev fun_args, t)
+ |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args))))
+ end;
fun build_corec_arg_no_call sel_eqns sel = find_first (equal sel o #sel) sel_eqns
- |> try (fn SOME sel_eqn => (#fun_args sel_eqn |> map dest_Free, #rhs_term sel_eqn))
+ |> try (fn SOME sel_eqn => (#fun_args sel_eqn, #rhs_term sel_eqn))
|> the_default ([], undef_const)
- |-> abs_tuple oo fold_rev absfree;
+ |-> abs_tuple;
fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
let
val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns
-
- fun rewrite U T t =
+ fun rewrite is_end U T t =
if U = @{typ bool} then @{term True} |> has_call t ? K @{term False} (* stop? *)
- else if T = U = has_call t then undef_const
- else if T = U then t (* end *)
+ else if is_end = has_call t then undef_const
+ else if is_end then t (* end *)
else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)
- fun massage rhs_term t =
- massage_direct_corec_call lthy has_call rewrite [] (body_type (fastype_of t)) rhs_term;
- val abstract = abs_tuple oo fold_rev absfree o map dest_Free;
+ fun massage rhs_term is_end t = massage_direct_corec_call
+ lthy has_call (rewrite is_end) [] (range_type (fastype_of t)) rhs_term;
in
- if is_none maybe_sel_eqn then I else
- massage (#rhs_term (the maybe_sel_eqn)) #> abstract (#fun_args (the maybe_sel_eqn))
+ if is_none maybe_sel_eqn then K I else
+ abs_tuple (#fun_args (the maybe_sel_eqn)) oo massage (#rhs_term (the maybe_sel_eqn))
end;
fun build_corec_arg_indirect_call sel_eqns sel =
@@ -614,7 +609,7 @@
(build_corec_arg_no_call sel_eqns sel |> K)) no_calls'
#> fold (fn (sel, (q, g, h)) =>
let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in
- nth_map h f o nth_map g f o nth_map q f end) direct_calls'
+ nth_map h (f false) o nth_map g (f true) o nth_map q (f true) end) direct_calls'
#> fold (fn (sel, n) => nth_map n
(build_corec_arg_indirect_call sel_eqns sel |> K)) indirect_calls'
end
@@ -651,24 +646,25 @@
|> fold2 build_corec_args_discs disc_eqnss ctr_specss
|> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
+ fun currys Ts t = if length Ts <= 1 then t else
+ t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
+ (length Ts - 1 downto 0 |> map Bound)
+ |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
+
val _ = tracing ("corecursor arguments:\n \<cdot> " ^
space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) corec_args));
fun uneq_pairs_rev xs = xs (* FIXME \<exists>? *)
|> these o try (split_last #> (fn (ys, y) => uneq_pairs_rev ys @ map (pair y) ys));
val proof_obligations = if sequential then [] else
- maps (uneq_pairs_rev o map #cond) disc_eqnss
- |> map (fn (x, y) => ((strip_abs_body x, strip_abs_body y), strip_abs_vars x))
- |> map (apfst (apsnd HOLogic.mk_not #> pairself HOLogic.mk_Trueprop
- #> apfst (curry (op $) @{const ==>}) #> (op $)))
- |> map (fn (t, abs_vars) => fold_rev (fn (v, T) => fn u =>
- Const (@{const_name all}, (T --> @{typ prop}) --> @{typ prop}) $
- Abs (v, T, u)) abs_vars t);
+ maps (uneq_pairs_rev o map (fn {fun_args, cond, ...} => (fun_args, cond))) disc_eqnss
+ |> map (fn ((fun_args, x), (_, y)) => [x, HOLogic.mk_not y]
+ |> map (HOLogic.mk_Trueprop o curry subst_bounds (List.rev fun_args))
+ |> curry list_comb @{const ==>});
- fun currys Ts t = if length Ts <= 1 then t else
- t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v)
- (length Ts - 1 downto 0 |> map Bound)
- |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts;
+val _ = tracing ("proof obligations:\n \<cdot> " ^
+ space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) proof_obligations));
+
in
map (list_comb o rpair corec_args) corecs
|> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss