--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML Wed Jan 08 09:20:14 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML Wed Jan 08 17:26:42 2014 +0100
@@ -93,6 +93,8 @@
fun unexpected_corec_call ctxt t =
error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
+fun order_list_duplicates xs = map snd (sort (int_ord o pairself fst) xs)
+
val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
val mk_dnf = mk_disjs o map mk_conjs;
@@ -474,12 +476,13 @@
fun_T: typ,
fun_args: term list,
ctr: term,
- ctr_no: int, (*FIXME*)
+ ctr_no: int,
disc: term,
prems: term list,
auto_gen: bool,
ctr_rhs_opt: term option,
code_rhs_opt: term option,
+ eqn_pos: int,
user_eqn: term
};
@@ -492,6 +495,7 @@
rhs_term: term,
ctr_rhs_opt: term option,
code_rhs_opt: term option,
+ eqn_pos: int,
user_eqn: term
};
@@ -500,7 +504,7 @@
Sel of coeqn_data_sel;
fun dissect_coeqn_disc fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list)
- ctr_rhs_opt code_rhs_opt prems' concl matchedsss =
+ eqn_pos ctr_rhs_opt code_rhs_opt prems' concl matchedsss =
let
fun find_subterm p =
let (* FIXME \<exists>? *)
@@ -558,12 +562,13 @@
auto_gen = catch_all,
ctr_rhs_opt = ctr_rhs_opt,
code_rhs_opt = code_rhs_opt,
+ eqn_pos = eqn_pos,
user_eqn = user_eqn
}, matchedsss')
end;
-fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) ctr_rhs_opt
- code_rhs_opt eqn' of_spec_opt eqn =
+fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn_pos
+ ctr_rhs_opt code_rhs_opt eqn' of_spec_opt eqn =
let
val (lhs, rhs) = HOLogic.dest_eq eqn
handle TERM _ =>
@@ -591,12 +596,13 @@
rhs_term = rhs,
ctr_rhs_opt = ctr_rhs_opt,
code_rhs_opt = code_rhs_opt,
+ eqn_pos = eqn_pos,
user_eqn = user_eqn
}
end;
-fun dissect_coeqn_ctr fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list) eqn'
- code_rhs_opt prems concl matchedsss =
+fun dissect_coeqn_ctr fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list)
+ eqn_pos eqn' code_rhs_opt prems concl matchedsss =
let
val (lhs, rhs) = HOLogic.dest_eq concl;
val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
@@ -608,7 +614,7 @@
val disc_concl = betapply (disc, lhs);
val (eqn_data_disc_opt, matchedsss') = if length basic_ctr_specs = 1
then (NONE, matchedsss)
- else apfst SOME (dissect_coeqn_disc fun_names sequentials basic_ctr_specss
+ else apfst SOME (dissect_coeqn_disc fun_names sequentials basic_ctr_specss eqn_pos
(SOME (abstract (List.rev fun_args) rhs)) code_rhs_opt prems disc_concl matchedsss);
val sel_concls = sels ~~ ctr_args
@@ -623,13 +629,13 @@
*)
val eqns_data_sel =
- map (dissect_coeqn_sel fun_names basic_ctr_specss
+ map (dissect_coeqn_sel fun_names basic_ctr_specss eqn_pos
(SOME (abstract (List.rev fun_args) rhs)) code_rhs_opt eqn' (SOME ctr)) sel_concls;
in
(the_list eqn_data_disc_opt @ eqns_data_sel, matchedsss')
end;
-fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss =
+fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn_pos eqn' concl matchedsss =
let
val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
@@ -651,13 +657,13 @@
val sequentials = replicate (length fun_names) false;
in
- fold_map2 (dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn'
+ fold_map2 (dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn_pos eqn'
(SOME (abstract (List.rev fun_args) rhs)))
ctr_premss ctr_concls matchedsss
end;
fun dissect_coeqn lthy has_call fun_names sequentials
- (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' of_spec_opt matchedsss =
+ (basic_ctr_specss : basic_corec_ctr_spec list list) (eqn_pos, eqn') of_spec_opt matchedsss =
let
val eqn = drop_All eqn'
handle TERM _ => primcorec_error_eqn "malformed function equation" eqn';
@@ -677,17 +683,17 @@
if member (op =) discs head orelse
is_some rhs_opt andalso
member (op =) (filter (null o binder_types o fastype_of) ctrs) (the rhs_opt) then
- dissect_coeqn_disc fun_names sequentials basic_ctr_specss NONE NONE prems concl matchedsss
+ dissect_coeqn_disc fun_names sequentials basic_ctr_specss eqn_pos NONE NONE prems concl matchedsss
|>> single
else if member (op =) sels head then
- ([dissect_coeqn_sel fun_names basic_ctr_specss NONE NONE eqn' of_spec_opt concl],
+ ([dissect_coeqn_sel fun_names basic_ctr_specss eqn_pos NONE NONE eqn' of_spec_opt concl],
matchedsss)
else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
member (op =) ctrs (head_of (unfold_let (the rhs_opt))) then
- dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn' NONE prems concl matchedsss
+ dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn_pos eqn' NONE prems concl matchedsss
else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
null prems then
- dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss
+ dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn_pos eqn' concl matchedsss
|>> flat
else
primcorec_error_eqn "malformed function equation" eqn
@@ -834,6 +840,7 @@
auto_gen = true,
ctr_rhs_opt = Option.map #ctr_rhs_opt sel_eqn_opt |> the_default NONE,
code_rhs_opt = Option.map #ctr_rhs_opt sel_eqn_opt |> the_default NONE,
+ eqn_pos = Option.map (curry (op +) 1 o #eqn_pos) sel_eqn_opt |> the_default 1000 (*###*),
user_eqn = undef_const};
in
chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
@@ -877,7 +884,7 @@
val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts;
val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
val eqns_data =
- fold_map2 (dissect_coeqn lthy has_call fun_names sequentials basic_ctr_specss) (map snd specs)
+ fold_map2 (dissect_coeqn lthy has_call fun_names sequentials basic_ctr_specss) (tag_list 0 (map snd specs))
of_specs_opt []
|> flat o fst;
@@ -897,7 +904,7 @@
val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms,
strong_coinduct_thms), lthy') =
corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
- val corec_specs = take actual_nn corec_specs'; (*FIXME*)
+ val corec_specs = take actual_nn corec_specs';
val ctr_specss = map #ctr_specs corec_specs;
val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
@@ -1011,7 +1018,7 @@
mk_excludesss excludes (length ctr_specs));
fun prove_disc ({ctr_specs, ...} : corec_spec) excludesss
- ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
+ ({fun_name, fun_T, fun_args, ctr_no, prems, eqn_pos, ...} : coeqn_data_disc) =
if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then
[]
else
@@ -1033,12 +1040,13 @@
|> K |> Goal.prove lthy [] [] goal
|> Thm.close_derivation
|> pair (#disc (nth ctr_specs ctr_no))
+ |> pair eqn_pos
|> single
end;
fun prove_sel ({nested_map_idents, nested_map_comps, ctr_specs, ...} : corec_spec)
(disc_eqns : coeqn_data_disc list) excludesss
- ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) =
+ ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, eqn_pos, ...} : coeqn_data_sel) =
let
val SOME ctr_spec = find_first (curry (op =) ctr o #ctr) ctr_specs;
val ctr_no = find_index (curry (op =) ctr o #ctr) ctr_specs;
@@ -1062,6 +1070,7 @@
|> K |> Goal.prove lthy [] [] goal
|> Thm.close_derivation
|> pair sel
+ |> pair eqn_pos
end;
fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list)
@@ -1075,12 +1084,13 @@
|> exists (null o snd)
then [] else
let
- val (fun_name, fun_T, fun_args, prems, rhs_opt) =
+ val (fun_name, fun_T, fun_args, prems, rhs_opt, eqn_pos) =
(find_first (curry (op =) ctr o #ctr) disc_eqns,
find_first (curry (op =) ctr o #ctr) sel_eqns)
|>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x,
- #ctr_rhs_opt x))
- ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], #ctr_rhs_opt x))
+ #ctr_rhs_opt x, #eqn_pos x))
+ ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], #ctr_rhs_opt x,
+ #eqn_pos x))
|> the o merge_options;
val m = length prems;
val goal =
@@ -1102,6 +1112,7 @@
|> K |> Goal.prove lthy [] [] goal
|> Thm.close_derivation
|> pair ctr
+ |> pair eqn_pos
|> single
end;
@@ -1149,7 +1160,7 @@
val ctr_conds_argss_opt = map prove_code_ctr ctr_specs;
val exhaustive_code =
exhaustive
- orelse forall null (map_filter (try (fst o the)) ctr_conds_argss_opt)
+ orelse exists (is_some andf (null o fst o the)) ctr_conds_argss_opt
orelse forall is_some ctr_conds_argss_opt
andalso exists #auto_gen disc_eqns;
val rhs =
@@ -1192,14 +1203,15 @@
end;
val disc_alistss = map3 (map oo prove_disc) corec_specs excludessss disc_eqnss;
- val disc_alists = map flat disc_alistss;
+ val disc_alists = map (map snd o flat) disc_alistss;
val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss excludessss sel_eqnss;
- val disc_thmsss = map (map (map snd)) disc_alistss;
- val disc_thmss = map flat disc_thmsss;
- val sel_thmss = map (map snd) sel_alists;
+ val disc_thmss = map (map snd o order_list_duplicates o flat) disc_alistss;
+ val disc_thmsss' = map (map (map (snd o snd))) disc_alistss;
+ val disc_thmss' = map flat disc_thmsss';
+ val sel_thmss = map (map snd o order_list_duplicates) sel_alists;
- fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss disc_thms
- ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
+ fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss' disc_thms
+ ({fun_name, fun_T, fun_args, ctr_no, prems, eqn_pos, ...} : coeqn_data_disc) =
if null disc_thms orelse null exhaust_thms then
[]
else
@@ -1214,23 +1226,26 @@
[]
else
mk_primcorec_disc_iff_tac lthy (map (fst o dest_Free) fun_args)
- (the_single exhaust_thms) (the_single disc_thms) disc_thmss (flat disc_excludess)
+ (the_single exhaust_thms) (the_single disc_thms) disc_thmss' (flat disc_excludess)
|> K |> Goal.prove lthy [] [] goal
|> Thm.close_derivation
+ |> pair eqn_pos
|> single
end;
val disc_iff_thmss = map5 (flat ooo map2 ooo prove_disc_iff) corec_specs exhaust_thmss
- disc_thmsss disc_thmsss disc_eqnss;
+ disc_thmsss' disc_thmsss' disc_eqnss
+ |> map order_list_duplicates;
- val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
- ctr_specss;
- val ctr_thmss = map (map snd) ctr_alists;
+ val ctr_alists = map5 (maps oooo prove_ctr) disc_alists (map (map snd) sel_alists) disc_eqnss
+ sel_eqnss ctr_specss;
+ val ctr_thmss' = map (map snd) ctr_alists;
+ val ctr_thmss = map (map snd o order_list) ctr_alists;
- val code_thmss = map6 prove_code exhaustives disc_eqnss sel_eqnss nchotomy_thmss ctr_alists
+ val code_thmss = map6 prove_code exhaustives disc_eqnss sel_eqnss nchotomy_thmss ctr_thmss'
ctr_specss;
- val simp_thmss = map2 append disc_thmss sel_thmss
+ val simp_thmss = map2 append disc_thmss sel_thmss;
val common_name = mk_common_name fun_names;