# HG changeset patch # User bulwahn # Date 1249367696 -7200 # Node ID 4a18f3cf63626247bd9884b4c36e2dba587e91fa # Parent c2b74affab85d1d7d499e72c0611217d85fede2d imported patch changed mode inference of predicate compiler to return infered dataflow diff -r c2b74affab85 -r 4a18f3cf6362 src/HOL/ex/predicate_compile.ML --- a/src/HOL/ex/predicate_compile.ML Tue Aug 04 08:34:56 2009 +0200 +++ b/src/HOL/ex/predicate_compile.ML Tue Aug 04 08:34:56 2009 +0200 @@ -44,9 +44,11 @@ val prepare_intrs: theory -> string list -> (string * typ) list * int * string list * string list * (string * mode list) list * (string * (term list * indprem list) list) list * (string * (int option list * int)) list + datatype tmode = Mode of mode * int list * tmode option list; val infer_modes : theory -> (string * (int list option list * int list) list) list -> (string * (int option list * int)) list -> string list - -> (string * (term list * indprem list) list) list -> (string * mode list) list + -> (string * (term list * indprem list) list) list + -> (string * (mode * ((term list * (indprem * tmode) list) list)) list) list val split_mode : int list -> term list -> (term list * term list) end; @@ -221,6 +223,7 @@ (* data structures *) type mode = int list option list * int list; (*pmode FIMXE*) +datatype tmode = Mode of mode * int list * tmode option list; fun string_of_mode (iss, is) = space_implode " -> " (map (fn NONE => "X" @@ -547,8 +550,6 @@ fun cprods xss = foldr (map op :: o cprod) [[]] xss; -datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand - why there is another mode type tmode !?*) (*TODO: cleanup function and put together with modes_of_term *) @@ -664,46 +665,57 @@ val modes' = modes @ List.mapPartial (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) (param_vs ~~ iss); - fun check_mode_prems vs [] = SOME vs - | check_mode_prems vs ps = (case select_mode_prem thy modes' vs ps of + fun check_mode_prems acc_ps vs [] = SOME (acc_ps, vs) + | check_mode_prems acc_ps vs ps = (case select_mode_prem thy modes' vs ps of NONE => NONE - | SOME (x, _) => check_mode_prems - (case x of Prem (us, _) => vs union terms_vs us | _ => vs) - (filter_out (equal x) ps)) + | SOME (p, SOME mode) => check_mode_prems ((p, mode) :: acc_ps) + (case p of Prem (us, _) => vs union terms_vs us | _ => vs) + (filter_out (equal p) ps)) val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (split_mode is ts)); val in_vs = terms_vs in_ts; val concl_vs = terms_vs ts - val _ = Output.tracing ("ts :" ^ (commas (map (Syntax.string_of_term_global thy) ts))) - val _ = () - val ret = forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso - forall (is_eqT o fastype_of) in_ts' andalso - (case check_mode_prems (param_vs union in_vs) ps of - NONE => false - | SOME vs => concl_vs subset vs) in - ret + if forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso + forall (is_eqT o fastype_of) in_ts' then + case check_mode_prems [] (param_vs union in_vs) ps of + NONE => NONE + | SOME (acc_ps, vs) => if concl_vs subset vs then SOME (ts, rev acc_ps) else NONE + else NONE end; fun check_modes_pred thy param_vs preds modes (p, ms) = let val SOME rs = AList.lookup (op =) preds p in (p, List.filter (fn m => case find_index - (not o check_mode_clause thy param_vs modes m) rs of + (is_none o check_mode_clause thy param_vs modes m) rs of ~1 => true - | i => (Output.tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^ + | i => (Output.tracing ("Clause " ^ string_of_int (i + 1) ^ " of " ^ p ^ " violates mode " ^ string_of_mode m); false)) ms) end; +fun get_modes_pred thy param_vs preds modes (p, ms) = + let + val SOME rs = AList.lookup (op =) preds p + in + (p, map (fn m => (m, map (the o check_mode_clause thy param_vs modes m) rs)) ms) + end; + fun fixp f (x : (string * mode list) list) = let val y = f x in if x = y then x else fixp f y end; -fun infer_modes thy extra_modes arities param_vs preds = fixp (fn modes => - map (check_modes_pred thy param_vs preds (modes @ extra_modes)) modes) - (map (fn (s, (ks, k)) => (s, cprod (cprods (map - (fn NONE => [NONE] - | SOME k' => map SOME (subsets 1 k')) ks), - subsets 1 k))) arities); - +fun infer_modes thy extra_modes arities param_vs preds = + let + val modes = + fixp (fn modes => + map (check_modes_pred thy param_vs preds (modes @ extra_modes)) modes) + (map (fn (s, (ks, k)) => (s, cprod (cprods (map + (fn NONE => [NONE] + | SOME k' => map SOME (subsets 1 k')) ks), + subsets 1 k))) arities) + in + map (get_modes_pred thy param_vs preds (modes @ extra_modes)) modes + end; + (* term construction *) fun mk_v (names, vs) s T = (case AList.lookup (op =) vs s of @@ -809,7 +821,7 @@ | compile_param _ _ _ _ = error "compile params" -fun compile_expr thy compfuns modes (SOME (Mode (mode, is, ms)), t) = +fun compile_expr thy compfuns modes ((Mode (mode, is, ms)), t) = (case strip_comb t of (Const (name, T), params) => if AList.defined op = modes name then @@ -827,7 +839,7 @@ | compile_expr _ _ _ _ = error "not a valid inductive expression" -fun compile_clause thy compfuns all_vs param_vs modes (iss, is) (ts, ps) inp = +fun compile_clause thy compfuns all_vs param_vs modes (iss, is) inp (ts, moded_ps) = let val modes' = modes @ List.mapPartial (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) @@ -853,12 +865,9 @@ compile_match thy compfuns constr_vs (eqs @ eqs') out_ts''' (mk_single compfuns (mk_tuple out_ts)) end - | compile_prems out_ts vs names ps = + | compile_prems out_ts vs names ((p, mode as Mode ((_, is), _, _)) :: ps) = let val vs' = distinct (op =) (flat (vs :: map term_vs out_ts)); - val SOME (p, mode as SOME (Mode (_, js, _))) = - select_mode_prem thy modes' vs' ps - val ps' = filter_out (equal p) ps val (out_ts', (names', eqs)) = fold_map check_constrt out_ts (names, []) val (out_ts'', (names'', constr_vs')) = fold_map distinct_v @@ -866,23 +875,23 @@ val (compiled_clause, rest) = case p of Prem (us, t) => let - val (in_ts, out_ts''') = split_mode js us; + val (in_ts, out_ts''') = split_mode is us; val u = list_comb (compile_expr thy compfuns modes (mode, t), in_ts) - val rest = compile_prems out_ts''' vs' names'' ps' + val rest = compile_prems out_ts''' vs' names'' ps in (u, rest) end | Negprem (us, t) => let - val (in_ts, out_ts''') = split_mode js us + val (in_ts, out_ts''') = split_mode is us val u = list_comb (compile_expr thy compfuns modes (mode, t), in_ts) - val rest = compile_prems out_ts''' vs' names'' ps' + val rest = compile_prems out_ts''' vs' names'' ps in (mk_not compfuns u, rest) end | Sidecond t => let - val rest = compile_prems [] vs' names'' ps'; + val rest = compile_prems [] vs' names'' ps; in (mk_if compfuns t, rest) end @@ -890,12 +899,12 @@ compile_match thy compfuns constr_vs' eqs out_ts'' (mk_bind compfuns (compiled_clause, rest)) end - val prem_t = compile_prems in_ts' param_vs all_vs' ps; + val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps; in mk_bind compfuns (mk_single compfuns inp, prem_t) end -fun compile_pred thy compfuns all_vs param_vs modes s T cls mode = +fun compile_pred thy compfuns all_vs param_vs modes s T mode moded_cls = let val Ts = binder_types T; val (Ts1, Ts2) = chop (length param_vs) Ts; @@ -905,8 +914,8 @@ (map (fn i => "x" ^ string_of_int i) (snd mode)); val xs = map2 (fn s => fn T => Free (s, T)) xnames Us1; val cl_ts = - map (fn cl => compile_clause thy compfuns - all_vs param_vs modes mode cl (mk_tuple xs)) cls; + map (compile_clause thy compfuns + all_vs param_vs modes mode (mk_tuple xs)) moded_cls; in HOLogic.mk_Trueprop (HOLogic.mk_eq (list_comb (mk_predfun_of thy compfuns (s, T) mode, @@ -914,11 +923,21 @@ foldr1 (mk_sup compfuns) cl_ts)) end; -fun compile_preds thy compfuns all_vs param_vs modes preds = - map (fn (s, (T, cls)) => - map (compile_pred thy compfuns all_vs param_vs modes s T cls) - ((the o AList.lookup (op =) modes) s)) preds; +fun map_preds_modes f preds_modes_table = + map (fn (pred, modes) => + (pred, map (fn (mode, value) => (mode, f pred mode value)) modes)) preds_modes_table +fun join_preds_modes table1 table2 = + map_preds_modes (fn pred => fn mode => fn value => + (value, the (AList.lookup (op =) (the (AList.lookup (op =) table2 pred)) mode))) table1 + +fun maps_modes preds_modes_table = + map (fn (pred, modes) => + (pred, map (fn (mode, value) => value) modes)) preds_modes_table + +fun compile_preds thy compfuns all_vs param_vs modes preds moded_clauses = + map_preds_modes (fn pred => compile_pred thy compfuns all_vs param_vs modes pred + (the (AList.lookup (op =) preds pred))) moded_clauses (* special setup for simpset *) val HOL_basic_ss' = HOL_basic_ss setSolver @@ -1076,9 +1095,9 @@ (* MAJOR FIXME: prove_params should be simple - different form of introrule for parameters ? *) -fun prove_param thy modes (NONE, t) = +fun prove_param thy (NONE, t) = all_tac -| prove_param thy modes (m as SOME (Mode (mode, is, ms)), t) = +| prove_param thy (m as SOME (Mode (mode, is, ms)), t) = REPEAT_DETERM (etac @{thm thin_rl} 1) THEN REPEAT_DETERM (rtac @{thm ext} 1) THEN (rtac @{thm iffI} 1) @@ -1105,21 +1124,20 @@ THEN (REPEAT_DETERM (atac 1)) end *) -fun prove_expr thy modes (SOME (Mode (mode, is, ms)), t, us) (premposition : int) = - (case strip_comb t of - (Const (name, T), args) => - if AList.defined op = modes name then (let - val introrule = predfun_intro_of thy name mode - (*val (in_args, out_args) = split_mode is us - val (pred, rargs) = strip_comb (HOLogic.dest_Trueprop - (hd (Logic.strip_imp_prems (prop_of introrule)))) - val nparams = length ms (* get_nparams thy (fst (dest_Const pred)) *) - val (_, args) = chop nparams rargs - val subst = map (pairself (cterm_of thy)) (args ~~ us) - val inst_introrule = Drule.cterm_instantiate subst introrule*) - (* the next line is old and probably wrong *) - val (args1, args2) = chop (length ms) args - in +fun prove_expr thy (Mode (mode, is, ms), t, us) (premposition : int) = + case strip_comb t of + (Const (name, T), args) => + let + val introrule = predfun_intro_of thy name mode + (*val (in_args, out_args) = split_mode is us + val (pred, rargs) = strip_comb (HOLogic.dest_Trueprop + (hd (Logic.strip_imp_prems (prop_of introrule)))) + val nparams = length ms (* get_nparams thy (fst (dest_Const pred)) *) + val (_, args) = chop nparams rargs + val subst = map (pairself (cterm_of thy)) (args ~~ us) + val inst_introrule = Drule.cterm_instantiate subst introrule*) + val (args1, args2) = chop (length ms) args + in rtac @{thm bindI} 1 THEN print_tac "before intro rule:" (* for the right assumption in first position *) @@ -1129,12 +1147,10 @@ (* work with parameter arguments *) THEN (atac 1) THEN (print_tac "parameter goal") - THEN (EVERY (map (prove_param thy modes) (ms ~~ args1))) - THEN (REPEAT_DETERM (atac 1)) end) - else error "Prove expr if case not implemented" - | _ => rtac @{thm bindI} 1 - THEN atac 1) - | prove_expr _ _ _ _ = error "Prove expr not implemented" + THEN (EVERY (map (prove_param thy) (ms ~~ args1))) + THEN (REPEAT_DETERM (atac 1)) + end + | _ => rtac @{thm bindI} 1 THEN atac 1 fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st; @@ -1178,101 +1194,85 @@ (* need better control here! *) end -fun prove_clause thy nargs all_vs param_vs modes (iss, is) (ts, ps) = let - val modes' = modes @ List.mapPartial - (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) - (param_vs ~~ iss); - fun check_constrt ((names, eqs), t) = - if is_constrt thy t then ((names, eqs), t) else - let - val s = Name.variant names "x"; - val v = Free (s, fastype_of t) - in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end; - - val (in_ts, clause_out_ts) = split_mode is ts; - val ((all_vs', eqs), in_ts') = - (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts); - fun prove_prems out_ts vs [] = - (prove_match thy out_ts) - THEN asm_simp_tac HOL_basic_ss' 1 - THEN print_tac "before the last rule of singleI:" - THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1) - | prove_prems out_ts vs rps = - let - val vs' = distinct (op =) (flat (vs :: map term_vs out_ts)); - val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) = - select_mode_prem thy modes' vs' rps; - val premposition = (find_index (equal p) ps) + nargs - val rps' = filter_out (equal p) rps; - val rest_tac = (case p of Prem (us, t) => - let - val (in_ts, out_ts''') = split_mode js us - val rec_tac = prove_prems out_ts''' vs' rps' - in - print_tac "before clause:" - THEN asm_simp_tac HOL_basic_ss 1 - THEN print_tac "before prove_expr:" - THEN prove_expr thy modes (mode, t, us) premposition - THEN print_tac "after prove_expr:" - THEN rec_tac - end - | Negprem (us, t) => - let - val (in_ts, out_ts''') = split_mode js us - val rec_tac = prove_prems out_ts''' vs' rps' - val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) - val (_, params) = strip_comb t - in - rtac @{thm bindI} 1 - THEN (if (is_some name) then - simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, js)]) 1 - THEN rtac @{thm not_predI} 1 - (* FIXME: work with parameter arguments *) - THEN (EVERY (map (prove_param thy modes) (param_modes ~~ params))) - else - rtac @{thm not_predI'} 1) - THEN (REPEAT_DETERM (atac 1)) - THEN rec_tac - end - | Sidecond t => - rtac @{thm bindI} 1 - THEN rtac @{thm if_predI} 1 - THEN print_tac "before sidecond:" - THEN prove_sidecond thy modes t - THEN print_tac "after sidecond:" - THEN prove_prems [] vs' rps') - in (prove_match thy out_ts) - THEN rest_tac - end; - val prems_tac = prove_prems in_ts' param_vs ps -in - rtac @{thm bindI} 1 - THEN rtac @{thm singleI} 1 - THEN prems_tac -end; +fun prove_clause thy nargs modes (iss, is) (_, clauses) (ts, moded_ps) = + let + val (in_ts, clause_out_ts) = split_mode is ts; + fun prove_prems out_ts [] = + (prove_match thy out_ts) + THEN asm_simp_tac HOL_basic_ss' 1 + THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1) + | prove_prems out_ts ((p, mode as Mode ((iss, is), _, param_modes)) :: ps) = + let + val premposition = (find_index (equal p) clauses) + nargs + val rest_tac = (case p of Prem (us, t) => + let + val (_, out_ts''') = split_mode is us + val rec_tac = prove_prems out_ts''' ps + in + print_tac "before clause:" + THEN asm_simp_tac HOL_basic_ss 1 + THEN print_tac "before prove_expr:" + THEN prove_expr thy (mode, t, us) premposition + THEN print_tac "after prove_expr:" + THEN rec_tac + end + | Negprem (us, t) => + let + val (_, out_ts''') = split_mode is us + val rec_tac = prove_prems out_ts''' ps + val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) + val (_, params) = strip_comb t + in + rtac @{thm bindI} 1 + THEN (if (is_some name) then + simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, is)]) 1 + THEN rtac @{thm not_predI} 1 + (* FIXME: work with parameter arguments *) + THEN (EVERY (map (prove_param thy) (param_modes ~~ params))) + else + rtac @{thm not_predI'} 1) + THEN (REPEAT_DETERM (atac 1)) + THEN rec_tac + end + | Sidecond t => + rtac @{thm bindI} 1 + THEN rtac @{thm if_predI} 1 + THEN print_tac "before sidecond:" + THEN prove_sidecond thy modes t + THEN print_tac "after sidecond:" + THEN prove_prems [] ps) + in (prove_match thy out_ts) + THEN rest_tac + end; + val prems_tac = prove_prems in_ts moded_ps + in + rtac @{thm bindI} 1 + THEN rtac @{thm singleI} 1 + THEN prems_tac + end; fun select_sup 1 1 = [] | select_sup _ 1 = [rtac @{thm supI1}] | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1)); -fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) = let -(* val ind_result = Inductive.the_inductive (ProofContext.init thy) pred - val index = find_index (fn s => s = pred) (#names (fst ind_result)) - val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *) - val nargs = length (binder_types T) - nparams_of thy pred - val pred_case_rule = singleton (ind_set_codegen_preproc thy) - (preprocess_elim thy nargs (the_elim_of thy pred)) - (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@{thm Predicate.memb_code}])*) -in - REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})) - THEN etac (predfun_elim_of thy pred mode) 1 - THEN etac pred_case_rule 1 - THEN (EVERY (map - (fn i => EVERY' (select_sup (length clauses) i) i) - (1 upto (length clauses)))) - THEN (EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses)) - THEN print_tac "proved one direction" -end; +fun prove_one_direction thy clauses preds modes pred mode moded_clauses = + let + val T = the (AList.lookup (op =) preds pred) + val nargs = length (binder_types T) - nparams_of thy pred + (* FIXME: preprocessing! *) + val pred_case_rule = singleton (ind_set_codegen_preproc thy) + (preprocess_elim thy nargs (the_elim_of thy pred)) + (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@{thm Predicate.memb_code}])*) + in + REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})) + THEN etac (predfun_elim_of thy pred mode) 1 + THEN etac pred_case_rule 1 + THEN (EVERY (map + (fn i => EVERY' (select_sup (length moded_clauses) i) i) + (1 upto (length moded_clauses)))) + THEN (EVERY (map2 (prove_clause thy nargs modes mode) clauses moded_clauses)) + THEN print_tac "proved one direction" + end; (** Proof in the other direction **) @@ -1303,9 +1303,10 @@ (* VERY LARGE SIMILIRATIY to function prove_param -- join both functions *) - -fun prove_param2 thy modes (NONE, t) = all_tac - | prove_param2 thy modes (m as SOME (Mode (mode, is, ms)), t) = let +(* TODO: remove function *) +(* +fun prove_param2 thy (NONE, t) = all_tac + | prove_param2 thy (m as SOME (Mode (mode, is, ms)), t) = let val (f, args) = strip_comb t val (params, _) = chop (length ms) args val f_tac = case f of @@ -1317,30 +1318,27 @@ print_tac "before simplification in prove_args:" THEN f_tac THEN print_tac "after simplification in prove_args" - THEN (EVERY (map (prove_param2 thy modes) (ms ~~ params))) + THEN (EVERY (map (prove_param2 thy) (ms ~~ params))) end +*) - -fun prove_expr2 thy modes (SOME (Mode (mode, is, ms)), t) = +fun prove_expr2 thy (Mode (mode, is, ms), t) = (case strip_comb t of (Const (name, T), args) => - if AList.defined op = modes name then - etac @{thm bindE} 1 - THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))) - THEN print_tac "prove_expr2-before" - THEN (debug_tac (Syntax.string_of_term_global thy - (prop_of (predfun_elim_of thy name mode)))) - THEN (etac (predfun_elim_of thy name mode) 1) - THEN print_tac "prove_expr2" - (* TODO -- FIXME: replace remove_last_goal*) - (* THEN (EVERY (replicate (length args) (remove_last_goal thy))) *) - THEN (EVERY (map (prove_param thy modes) (ms ~~ args))) - THEN print_tac "finished prove_expr2" - - else error "Prove expr2 if case not implemented" + etac @{thm bindE} 1 + THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))) + THEN print_tac "prove_expr2-before" + THEN (debug_tac (Syntax.string_of_term_global thy + (prop_of (predfun_elim_of thy name mode)))) + THEN (etac (predfun_elim_of thy name mode) 1) + THEN print_tac "prove_expr2" + THEN (EVERY (map (prove_param thy) (ms ~~ args))) + THEN print_tac "finished prove_expr2" | _ => etac @{thm bindE} 1) - | prove_expr2 _ _ _ = error "Prove expr2 not implemented" - + +(* FIXME: what is this for? *) +(* replace defined by has_mode thy pred *) +(* TODO: rewrite function *) fun prove_sidecond2 thy modes t = let fun preds_of t nameTs = case strip_comb t of (f as Const (name, T), args) => @@ -1358,130 +1356,113 @@ THEN print_tac "after sidecond2 simplification" end -fun prove_clause2 thy all_vs param_vs modes (iss, is) (ts, ps) pred i = let - val modes' = modes @ List.mapPartial - (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) - (param_vs ~~ iss); - fun check_constrt ((names, eqs), t) = - if is_constrt thy t then ((names, eqs), t) else - let - val s = Name.variant names "x"; - val v = Free (s, fastype_of t) - in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end; - val pred_intro_rule = nth (intros_of thy pred) (i - 1) - |> preprocess_intro thy - |> (fn thm => hd (ind_set_codegen_preproc thy [thm])) - (* FIXME preprocess |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]) *) - val (in_ts, clause_out_ts) = split_mode is ts; - val ((all_vs', eqs), in_ts') = - (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts); - fun prove_prems2 out_ts vs [] = - print_tac "before prove_match2 - last call:" - THEN prove_match2 thy out_ts - THEN print_tac "after prove_match2 - last call:" - THEN (etac @{thm singleE} 1) - THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1)) - THEN (asm_full_simp_tac HOL_basic_ss' 1) - THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1)) - THEN (asm_full_simp_tac HOL_basic_ss' 1) - THEN SOLVED (print_tac "state before applying intro rule:" +fun prove_clause2 thy modes pred (iss, is) (ts, ps) i = + let + val pred_intro_rule = nth (intros_of thy pred) (i - 1) + |> preprocess_intro thy + |> (fn thm => hd (ind_set_codegen_preproc thy [thm])) + (* FIXME preprocess |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]) *) + val (in_ts, clause_out_ts) = split_mode is ts; + fun prove_prems2 out_ts [] = + print_tac "before prove_match2 - last call:" + THEN prove_match2 thy out_ts + THEN print_tac "after prove_match2 - last call:" + THEN (etac @{thm singleE} 1) + THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1)) + THEN (asm_full_simp_tac HOL_basic_ss' 1) + THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1)) + THEN (asm_full_simp_tac HOL_basic_ss' 1) + THEN SOLVED (print_tac "state before applying intro rule:" THEN (rtac pred_intro_rule 1) (* How to handle equality correctly? *) THEN (print_tac "state before assumption matching") THEN (REPEAT (atac 1 ORELSE (CHANGED (asm_full_simp_tac HOL_basic_ss' 1) THEN print_tac "state after simp_tac:")))) - | prove_prems2 out_ts vs ps = let - val vs' = distinct (op =) (flat (vs :: map term_vs out_ts)); - val SOME (p, mode as SOME (Mode ((iss, js), _, param_modes))) = - select_mode_prem thy modes' vs' ps; - val ps' = filter_out (equal p) ps; - val rest_tac = (case p of Prem (us, t) => + | prove_prems2 out_ts ((p, mode as Mode ((iss, is), _, param_modes)) :: ps) = + let + val rest_tac = (case p of + Prem (us, t) => let - val (in_ts, out_ts''') = split_mode js us - val rec_tac = prove_prems2 out_ts''' vs' ps' + val (_, out_ts''') = split_mode is us + val rec_tac = prove_prems2 out_ts''' ps in - (prove_expr2 thy modes (mode, t)) THEN rec_tac + (prove_expr2 thy (mode, t)) THEN rec_tac end | Negprem (us, t) => let - val (in_ts, out_ts''') = split_mode js us - val rec_tac = prove_prems2 out_ts''' vs' ps' + val (_, out_ts''') = split_mode is us + val rec_tac = prove_prems2 out_ts''' ps val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) val (_, params) = strip_comb t in print_tac "before neg prem 2" THEN etac @{thm bindE} 1 THEN (if is_some name then - full_simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, js)]) 1 + full_simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, is)]) 1 THEN etac @{thm not_predE} 1 - THEN (EVERY (map (prove_param thy modes) (param_modes ~~ params))) + THEN (EVERY (map (prove_param thy) (param_modes ~~ params))) else etac @{thm not_predE'} 1) THEN rec_tac end | Sidecond t => - etac @{thm bindE} 1 - THEN etac @{thm if_predE} 1 - THEN prove_sidecond2 thy modes t - THEN prove_prems2 [] vs' ps') - in print_tac "before prove_match2:" - THEN prove_match2 thy out_ts - THEN print_tac "after prove_match2:" - THEN rest_tac - end; - val prems_tac = prove_prems2 in_ts' param_vs ps -in - print_tac "starting prove_clause2" - THEN etac @{thm bindE} 1 - THEN (etac @{thm singleE'} 1) - THEN (TRY (etac @{thm Pair_inject} 1)) - THEN print_tac "after singleE':" - THEN prems_tac -end; + etac @{thm bindE} 1 + THEN etac @{thm if_predE} 1 + THEN prove_sidecond2 thy modes t + THEN prove_prems2 [] ps) + in print_tac "before prove_match2:" + THEN prove_match2 thy out_ts + THEN print_tac "after prove_match2:" + THEN rest_tac + end; + val prems_tac = prove_prems2 in_ts ps + in + print_tac "starting prove_clause2" + THEN etac @{thm bindE} 1 + THEN (etac @{thm singleE'} 1) + THEN (TRY (etac @{thm Pair_inject} 1)) + THEN print_tac "after singleE':" + THEN prems_tac + end; -fun prove_other_direction thy all_vs param_vs modes clauses (pred, mode) = let - fun prove_clause (clause, i) = - (if i < length clauses then etac @{thm supE} 1 else all_tac) - THEN (prove_clause2 thy all_vs param_vs modes mode clause pred i) -in - (DETERM (TRY (rtac @{thm unit.induct} 1))) - THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all}))) - THEN (rtac (predfun_intro_of thy pred mode) 1) - THEN (REPEAT_DETERM (rtac @{thm refl} 2)) - THEN (EVERY (map prove_clause (clauses ~~ (1 upto (length clauses))))) -end; +fun prove_other_direction thy modes pred mode moded_clauses = + let + fun prove_clause clause i = + (if i < length moded_clauses then etac @{thm supE} 1 else all_tac) + THEN (prove_clause2 thy modes pred mode clause i) + in + (DETERM (TRY (rtac @{thm unit.induct} 1))) + THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all}))) + THEN (rtac (predfun_intro_of thy pred mode) 1) + THEN (REPEAT_DETERM (rtac @{thm refl} 2)) + THEN (EVERY (map2 prove_clause moded_clauses (1 upto (length moded_clauses)))) + end; (** proof procedure **) -fun prove_pred thy all_vs param_vs modes clauses (((pred, T), mode), t) = let - val ctxt = ProofContext.init thy - val clauses' = the (AList.lookup (op =) clauses pred) -in - Goal.prove ctxt (Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) t []) [] t - (if !do_proofs then - (fn _ => - rtac @{thm pred_iffI} 1 - THEN prove_one_direction thy all_vs param_vs modes clauses' ((pred, T), mode) - THEN print_tac "proved one direction" - THEN prove_other_direction thy all_vs param_vs modes clauses' (pred, mode) - THEN print_tac "proved other direction") - else (fn _ => mycheat_tac thy 1)) -end; +fun prove_pred thy clauses preds modes pred mode (moded_clauses, compiled_term) = + let + val ctxt = ProofContext.init thy + val clauses = the (AList.lookup (op =) clauses pred) + in + Goal.prove ctxt (Term.add_free_names compiled_term []) [] compiled_term + (if !do_proofs then + (fn _ => + rtac @{thm pred_iffI} 1 + THEN prove_one_direction thy clauses preds modes pred mode moded_clauses + THEN print_tac "proved one direction" + THEN prove_other_direction thy modes pred mode moded_clauses + THEN print_tac "proved other direction") + else (fn _ => mycheat_tac thy 1)) + end; -fun prove_preds thy all_vs param_vs modes clauses pmts = - map (prove_pred thy all_vs param_vs modes clauses) pmts +fun prove_preds thy clauses preds modes = + map_preds_modes (prove_pred thy clauses preds modes) -fun rpred_prove_preds thy pmts = - let - fun prove_pred (((pred, T), mode), t) = - let - val _ = Output.tracing ("prove_preds:" ^ Syntax.string_of_term_global thy t) - in SkipProof.make_thm thy t end - in - map prove_pred pmts - end - +fun rpred_prove_preds thy = + map_preds_modes (fn pred => fn mode => fn t => SkipProof.make_thm thy t) + fun prepare_intrs thy prednames = let val intrs = maps (intros_of thy) prednames @@ -1526,17 +1507,6 @@ val (clauses, arities) = fold add_clause intrs ([], []); in (preds, nparams, all_vs, param_vs, extra_modes, clauses, arities) end; -fun arrange kvs = - let - fun add (key, value) table = - AList.update op = (key, these (AList.lookup op = table key) @ [value]) table - in fold add kvs [] end; - -fun add_generators_to_clauses thy all_vs clauses = - let - val _ = Output.tracing ("all_vs:" ^ commas all_vs) - in [clauses] end; - (* main function *) fun add_equations_of rpred prednames thy = @@ -1545,43 +1515,29 @@ val (preds, nparams, all_vs, param_vs, extra_modes, clauses, arities) = prepare_intrs thy prednames val _ = tracing "Infering modes..." - (* - val clauses_with_generators = add_generators_to_clauses thy all_vs clauses - val modess = map (infer_modes thy extra_modes arities param_vs) clauses_with_generators - fun print_modess (clauses, modes) = - let - val _ = print_clausess thy clauses - val _ = print_modes modes - in - () - end; - val _ = map print_modess (clauses_with_generators ~~ modess) - *) - val modes = infer_modes thy extra_modes arities param_vs clauses - val _ = print_modes modes + val moded_clauses = infer_modes thy extra_modes arities param_vs clauses + val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses val _ = tracing "Defining executable functions..." val thy' = (if rpred then fold (rpred_create_definitions preds nparams) modes thy else fold (create_definitions preds nparams) modes thy) |> Theory.checkpoint - val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses val _ = tracing "Compiling equations..." val compfuns = if rpred then RPredCompFuns.compfuns else PredicateCompFuns.compfuns - val ts = compile_preds thy' compfuns all_vs param_vs (extra_modes @ modes) clauses' - val _ = map (Output.tracing o (Syntax.string_of_term_global thy')) (flat ts) - val pred_mode = - maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses' + val compiled_terms = + compile_preds thy' compfuns all_vs param_vs (extra_modes @ modes) preds moded_clauses val _ = tracing "Proving equations..." val result_thms = if rpred then - rpred_prove_preds thy' (pred_mode ~~ (flat ts)) + rpred_prove_preds thy' compiled_terms else - prove_preds thy' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts)) + prove_preds thy' clauses preds (extra_modes @ modes) + (join_preds_modes moded_clauses compiled_terms) val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss [((Binding.qualify true (Long_Name.base_name name) (Binding.name "equation"), result_thms), [Attrib.attribute_i thy Code.add_default_eqn_attrib])] thy)) - (arrange ((map (fn ((name, _), _) => name) pred_mode) ~~ result_thms)) thy' + (maps_modes result_thms) thy' |> Theory.checkpoint in thy'' @@ -1620,7 +1576,7 @@ val scc = strong_conn_of (PredData.get thy') [name] val thy'' = fold_rev (fn preds => fn thy => - if forall (null o modes_of thy) preds then add_equations_of true preds thy else thy) + if forall (null o modes_of thy) preds then add_equations_of false preds thy else thy) scc thy' |> Theory.checkpoint in thy'' end @@ -1717,7 +1673,8 @@ | [m] => m | m :: _ :: _ => (warning ("Multiple modes possible for comprehension " ^ Syntax.string_of_term_global thy t_compr); m); - val t_eval = list_comb (compile_expr thy PredicateCompFuns.compfuns (all_modes_of thy) (SOME m, list_comb (pred, params)), + val t_eval = list_comb (compile_expr thy PredicateCompFuns.compfuns (all_modes_of thy) + (m, list_comb (pred, params)), inargs) in t_eval end;