# HG changeset patch # User panny # Date 1389198402 -3600 # Node ID 516adecd99dd5087b1698c046ad406d8c6cf07b9 # Parent 9a52ee8cae9bc791356504e1817339c39fde56a8 match order of generated theorems to user input; improve exhaustiveness criterion to prevent exception diff -r 9a52ee8cae9b -r 516adecd99dd src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML --- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML Wed Jan 08 09:20:14 2014 +0100 +++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML Wed Jan 08 17:26:42 2014 +0100 @@ -93,6 +93,8 @@ fun unexpected_corec_call ctxt t = error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t)); +fun order_list_duplicates xs = map snd (sort (int_ord o pairself fst) xs) + val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True}; val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False}; val mk_dnf = mk_disjs o map mk_conjs; @@ -474,12 +476,13 @@ fun_T: typ, fun_args: term list, ctr: term, - ctr_no: int, (*FIXME*) + ctr_no: int, disc: term, prems: term list, auto_gen: bool, ctr_rhs_opt: term option, code_rhs_opt: term option, + eqn_pos: int, user_eqn: term }; @@ -492,6 +495,7 @@ rhs_term: term, ctr_rhs_opt: term option, code_rhs_opt: term option, + eqn_pos: int, user_eqn: term }; @@ -500,7 +504,7 @@ Sel of coeqn_data_sel; fun dissect_coeqn_disc fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list) - ctr_rhs_opt code_rhs_opt prems' concl matchedsss = + eqn_pos ctr_rhs_opt code_rhs_opt prems' concl matchedsss = let fun find_subterm p = let (* FIXME \? *) @@ -558,12 +562,13 @@ auto_gen = catch_all, ctr_rhs_opt = ctr_rhs_opt, code_rhs_opt = code_rhs_opt, + eqn_pos = eqn_pos, user_eqn = user_eqn }, matchedsss') end; -fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) ctr_rhs_opt - code_rhs_opt eqn' of_spec_opt eqn = +fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn_pos + ctr_rhs_opt code_rhs_opt eqn' of_spec_opt eqn = let val (lhs, rhs) = HOLogic.dest_eq eqn handle TERM _ => @@ -591,12 +596,13 @@ rhs_term = rhs, ctr_rhs_opt = ctr_rhs_opt, code_rhs_opt = code_rhs_opt, + eqn_pos = eqn_pos, user_eqn = user_eqn } end; -fun dissect_coeqn_ctr fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' - code_rhs_opt prems concl matchedsss = +fun dissect_coeqn_ctr fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list) + eqn_pos eqn' code_rhs_opt prems concl matchedsss = let val (lhs, rhs) = HOLogic.dest_eq concl; val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; @@ -608,7 +614,7 @@ val disc_concl = betapply (disc, lhs); val (eqn_data_disc_opt, matchedsss') = if length basic_ctr_specs = 1 then (NONE, matchedsss) - else apfst SOME (dissect_coeqn_disc fun_names sequentials basic_ctr_specss + else apfst SOME (dissect_coeqn_disc fun_names sequentials basic_ctr_specss eqn_pos (SOME (abstract (List.rev fun_args) rhs)) code_rhs_opt prems disc_concl matchedsss); val sel_concls = sels ~~ ctr_args @@ -623,13 +629,13 @@ *) val eqns_data_sel = - map (dissect_coeqn_sel fun_names basic_ctr_specss + map (dissect_coeqn_sel fun_names basic_ctr_specss eqn_pos (SOME (abstract (List.rev fun_args) rhs)) code_rhs_opt eqn' (SOME ctr)) sel_concls; in (the_list eqn_data_disc_opt @ eqns_data_sel, matchedsss') end; -fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss = +fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn_pos eqn' concl matchedsss = let val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []); val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; @@ -651,13 +657,13 @@ val sequentials = replicate (length fun_names) false; in - fold_map2 (dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn' + fold_map2 (dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn_pos eqn' (SOME (abstract (List.rev fun_args) rhs))) ctr_premss ctr_concls matchedsss end; fun dissect_coeqn lthy has_call fun_names sequentials - (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' of_spec_opt matchedsss = + (basic_ctr_specss : basic_corec_ctr_spec list list) (eqn_pos, eqn') of_spec_opt matchedsss = let val eqn = drop_All eqn' handle TERM _ => primcorec_error_eqn "malformed function equation" eqn'; @@ -677,17 +683,17 @@ if member (op =) discs head orelse is_some rhs_opt andalso member (op =) (filter (null o binder_types o fastype_of) ctrs) (the rhs_opt) then - dissect_coeqn_disc fun_names sequentials basic_ctr_specss NONE NONE prems concl matchedsss + dissect_coeqn_disc fun_names sequentials basic_ctr_specss eqn_pos NONE NONE prems concl matchedsss |>> single else if member (op =) sels head then - ([dissect_coeqn_sel fun_names basic_ctr_specss NONE NONE eqn' of_spec_opt concl], + ([dissect_coeqn_sel fun_names basic_ctr_specss eqn_pos NONE NONE eqn' of_spec_opt concl], matchedsss) else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso member (op =) ctrs (head_of (unfold_let (the rhs_opt))) then - dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn' NONE prems concl matchedsss + dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn_pos eqn' NONE prems concl matchedsss else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso null prems then - dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss + dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn_pos eqn' concl matchedsss |>> flat else primcorec_error_eqn "malformed function equation" eqn @@ -834,6 +840,7 @@ auto_gen = true, ctr_rhs_opt = Option.map #ctr_rhs_opt sel_eqn_opt |> the_default NONE, code_rhs_opt = Option.map #ctr_rhs_opt sel_eqn_opt |> the_default NONE, + eqn_pos = Option.map (curry (op +) 1 o #eqn_pos) sel_eqn_opt |> the_default 1000 (*###*), user_eqn = undef_const}; in chop n disc_eqns ||> cons extra_disc_eqn |> (op @) @@ -877,7 +884,7 @@ val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts; val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); val eqns_data = - fold_map2 (dissect_coeqn lthy has_call fun_names sequentials basic_ctr_specss) (map snd specs) + fold_map2 (dissect_coeqn lthy has_call fun_names sequentials basic_ctr_specss) (tag_list 0 (map snd specs)) of_specs_opt [] |> flat o fst; @@ -897,7 +904,7 @@ val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms, strong_coinduct_thms), lthy') = corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; - val corec_specs = take actual_nn corec_specs'; (*FIXME*) + val corec_specs = take actual_nn corec_specs'; val ctr_specss = map #ctr_specs corec_specs; val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data @@ -1011,7 +1018,7 @@ mk_excludesss excludes (length ctr_specs)); fun prove_disc ({ctr_specs, ...} : corec_spec) excludesss - ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) = + ({fun_name, fun_T, fun_args, ctr_no, prems, eqn_pos, ...} : coeqn_data_disc) = if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\x. x = x"}) then [] else @@ -1033,12 +1040,13 @@ |> K |> Goal.prove lthy [] [] goal |> Thm.close_derivation |> pair (#disc (nth ctr_specs ctr_no)) + |> pair eqn_pos |> single end; fun prove_sel ({nested_map_idents, nested_map_comps, ctr_specs, ...} : corec_spec) (disc_eqns : coeqn_data_disc list) excludesss - ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) = + ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, eqn_pos, ...} : coeqn_data_sel) = let val SOME ctr_spec = find_first (curry (op =) ctr o #ctr) ctr_specs; val ctr_no = find_index (curry (op =) ctr o #ctr) ctr_specs; @@ -1062,6 +1070,7 @@ |> K |> Goal.prove lthy [] [] goal |> Thm.close_derivation |> pair sel + |> pair eqn_pos end; fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list) @@ -1075,12 +1084,13 @@ |> exists (null o snd) then [] else let - val (fun_name, fun_T, fun_args, prems, rhs_opt) = + val (fun_name, fun_T, fun_args, prems, rhs_opt, eqn_pos) = (find_first (curry (op =) ctr o #ctr) disc_eqns, find_first (curry (op =) ctr o #ctr) sel_eqns) |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x, - #ctr_rhs_opt x)) - ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], #ctr_rhs_opt x)) + #ctr_rhs_opt x, #eqn_pos x)) + ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], #ctr_rhs_opt x, + #eqn_pos x)) |> the o merge_options; val m = length prems; val goal = @@ -1102,6 +1112,7 @@ |> K |> Goal.prove lthy [] [] goal |> Thm.close_derivation |> pair ctr + |> pair eqn_pos |> single end; @@ -1149,7 +1160,7 @@ val ctr_conds_argss_opt = map prove_code_ctr ctr_specs; val exhaustive_code = exhaustive - orelse forall null (map_filter (try (fst o the)) ctr_conds_argss_opt) + orelse exists (is_some andf (null o fst o the)) ctr_conds_argss_opt orelse forall is_some ctr_conds_argss_opt andalso exists #auto_gen disc_eqns; val rhs = @@ -1192,14 +1203,15 @@ end; val disc_alistss = map3 (map oo prove_disc) corec_specs excludessss disc_eqnss; - val disc_alists = map flat disc_alistss; + val disc_alists = map (map snd o flat) disc_alistss; val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss excludessss sel_eqnss; - val disc_thmsss = map (map (map snd)) disc_alistss; - val disc_thmss = map flat disc_thmsss; - val sel_thmss = map (map snd) sel_alists; + val disc_thmss = map (map snd o order_list_duplicates o flat) disc_alistss; + val disc_thmsss' = map (map (map (snd o snd))) disc_alistss; + val disc_thmss' = map flat disc_thmsss'; + val sel_thmss = map (map snd o order_list_duplicates) sel_alists; - fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss disc_thms - ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) = + fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss' disc_thms + ({fun_name, fun_T, fun_args, ctr_no, prems, eqn_pos, ...} : coeqn_data_disc) = if null disc_thms orelse null exhaust_thms then [] else @@ -1214,23 +1226,26 @@ [] else mk_primcorec_disc_iff_tac lthy (map (fst o dest_Free) fun_args) - (the_single exhaust_thms) (the_single disc_thms) disc_thmss (flat disc_excludess) + (the_single exhaust_thms) (the_single disc_thms) disc_thmss' (flat disc_excludess) |> K |> Goal.prove lthy [] [] goal |> Thm.close_derivation + |> pair eqn_pos |> single end; val disc_iff_thmss = map5 (flat ooo map2 ooo prove_disc_iff) corec_specs exhaust_thmss - disc_thmsss disc_thmsss disc_eqnss; + disc_thmsss' disc_thmsss' disc_eqnss + |> map order_list_duplicates; - val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss - ctr_specss; - val ctr_thmss = map (map snd) ctr_alists; + val ctr_alists = map5 (maps oooo prove_ctr) disc_alists (map (map snd) sel_alists) disc_eqnss + sel_eqnss ctr_specss; + val ctr_thmss' = map (map snd) ctr_alists; + val ctr_thmss = map (map snd o order_list) ctr_alists; - val code_thmss = map6 prove_code exhaustives disc_eqnss sel_eqnss nchotomy_thmss ctr_alists + val code_thmss = map6 prove_code exhaustives disc_eqnss sel_eqnss nchotomy_thmss ctr_thmss' ctr_specss; - val simp_thmss = map2 append disc_thmss sel_thmss + val simp_thmss = map2 append disc_thmss sel_thmss; val common_name = mk_common_name fun_names;