# HG changeset patch # User bulwahn # Date 1287681189 -7200 # Node ID ea46574ca81581a72a1b4ee2236c1df71159a0f7 # Parent b6acda4d1c29186c5895196562387a3da1e71793 splitting large core file into core_data, mode_inference and predicate_compile_proof diff -r b6acda4d1c29 -r ea46574ca815 src/HOL/Predicate_Compile.thy --- a/src/HOL/Predicate_Compile.thy Thu Oct 21 19:13:09 2010 +0200 +++ b/src/HOL/Predicate_Compile.thy Thu Oct 21 19:13:09 2010 +0200 @@ -9,6 +9,9 @@ uses "Tools/Predicate_Compile/predicate_compile_aux.ML" "Tools/Predicate_Compile/predicate_compile_compilations.ML" + "Tools/Predicate_Compile/core_data.ML" + "Tools/Predicate_Compile/mode_inference.ML" + "Tools/Predicate_Compile/predicate_compile_proof.ML" "Tools/Predicate_Compile/predicate_compile_core.ML" "Tools/Predicate_Compile/predicate_compile_data.ML" "Tools/Predicate_Compile/predicate_compile_fun.ML" diff -r b6acda4d1c29 -r ea46574ca815 src/HOL/Tools/Predicate_Compile/core_data.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Predicate_Compile/core_data.ML Thu Oct 21 19:13:09 2010 +0200 @@ -0,0 +1,179 @@ + +structure Core_Data = +struct + +open Predicate_Compile_Aux; + +(* book-keeping *) + +datatype predfun_data = PredfunData of { + definition : thm, + intro : thm, + elim : thm, + neg_intro : thm option +}; + +fun rep_predfun_data (PredfunData data) = data; + +fun mk_predfun_data (definition, ((intro, elim), neg_intro)) = + PredfunData {definition = definition, intro = intro, elim = elim, neg_intro = neg_intro} + +datatype pred_data = PredData of { + intros : (string option * thm) list, + elim : thm option, + preprocessed : bool, + function_names : (compilation * (mode * string) list) list, + predfun_data : (mode * predfun_data) list, + needs_random : mode list +}; + +fun rep_pred_data (PredData data) = data; + +fun mk_pred_data (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random))) = + PredData {intros = intros, elim = elim, preprocessed = preprocessed, + function_names = function_names, predfun_data = predfun_data, needs_random = needs_random} + +fun map_pred_data f (PredData {intros, elim, preprocessed, function_names, predfun_data, needs_random}) = + mk_pred_data (f (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random)))) + +fun eq_option eq (NONE, NONE) = true + | eq_option eq (SOME x, SOME y) = eq (x, y) + | eq_option eq _ = false + +fun eq_pred_data (PredData d1, PredData d2) = + eq_list (eq_pair (op =) Thm.eq_thm) (#intros d1, #intros d2) andalso + eq_option Thm.eq_thm (#elim d1, #elim d2) + +structure PredData = Theory_Data +( + type T = pred_data Graph.T; + val empty = Graph.empty; + val extend = I; + val merge = Graph.merge eq_pred_data; +); + +(* queries *) + +fun lookup_pred_data ctxt name = + Option.map rep_pred_data (try (Graph.get_node (PredData.get (ProofContext.theory_of ctxt))) name) + +fun the_pred_data ctxt name = case lookup_pred_data ctxt name + of NONE => error ("No such predicate " ^ quote name) + | SOME data => data; + +val is_registered = is_some oo lookup_pred_data + +val all_preds_of = Graph.keys o PredData.get o ProofContext.theory_of + +val intros_of = map snd o #intros oo the_pred_data + +val names_of = map fst o #intros oo the_pred_data + +fun the_elim_of ctxt name = case #elim (the_pred_data ctxt name) + of NONE => error ("No elimination rule for predicate " ^ quote name) + | SOME thm => thm + +val has_elim = is_some o #elim oo the_pred_data + +fun function_names_of compilation ctxt name = + case AList.lookup (op =) (#function_names (the_pred_data ctxt name)) compilation of + NONE => error ("No " ^ string_of_compilation compilation + ^ " functions defined for predicate " ^ quote name) + | SOME fun_names => fun_names + +fun function_name_of compilation ctxt name mode = + case AList.lookup eq_mode + (function_names_of compilation ctxt name) mode of + NONE => error ("No " ^ string_of_compilation compilation + ^ " function defined for mode " ^ string_of_mode mode ^ " of predicate " ^ quote name) + | SOME function_name => function_name + +fun modes_of compilation ctxt name = map fst (function_names_of compilation ctxt name) + +fun all_modes_of compilation ctxt = + map_filter (fn name => Option.map (pair name) (try (modes_of compilation ctxt) name)) + (all_preds_of ctxt) + +val all_random_modes_of = all_modes_of Random + +fun defined_functions compilation ctxt name = case lookup_pred_data ctxt name of + NONE => false + | SOME data => AList.defined (op =) (#function_names data) compilation + +fun needs_random ctxt s m = + member (op =) (#needs_random (the_pred_data ctxt s)) m + +fun lookup_predfun_data ctxt name mode = + Option.map rep_predfun_data + (AList.lookup (op =) (#predfun_data (the_pred_data ctxt name)) mode) + +fun the_predfun_data ctxt name mode = + case lookup_predfun_data ctxt name mode of + NONE => error ("No function defined for mode " ^ string_of_mode mode ^ + " of predicate " ^ name) + | SOME data => data; + +val predfun_definition_of = #definition ooo the_predfun_data + +val predfun_intro_of = #intro ooo the_predfun_data + +val predfun_elim_of = #elim ooo the_predfun_data + +val predfun_neg_intro_of = #neg_intro ooo the_predfun_data + +val intros_graph_of = + Graph.map (K (map snd o #intros o rep_pred_data)) o PredData.get o ProofContext.theory_of + +(* registration of alternative function names *) + +structure Alt_Compilations_Data = Theory_Data +( + type T = (mode * (compilation_funs -> typ -> term)) list Symtab.table; + val empty = Symtab.empty; + val extend = I; + fun merge data : T = Symtab.merge (K true) data; +); + +fun alternative_compilation_of_global thy pred_name mode = + AList.lookup eq_mode (Symtab.lookup_list (Alt_Compilations_Data.get thy) pred_name) mode + +fun alternative_compilation_of ctxt pred_name mode = + AList.lookup eq_mode + (Symtab.lookup_list (Alt_Compilations_Data.get (ProofContext.theory_of ctxt)) pred_name) mode + +fun force_modes_and_compilations pred_name compilations = + let + (* thm refl is a dummy thm *) + val modes = map fst compilations + val (needs_random, non_random_modes) = pairself (map fst) + (List.partition (fn (m, (fun_name, random)) => random) compilations) + val non_random_dummys = map (rpair "dummy") non_random_modes + val all_dummys = map (rpair "dummy") modes + val dummy_function_names = map (rpair all_dummys) Predicate_Compile_Aux.random_compilations + @ map (rpair non_random_dummys) Predicate_Compile_Aux.non_random_compilations + val alt_compilations = map (apsnd fst) compilations + in + PredData.map (Graph.new_node + (pred_name, mk_pred_data ((([], SOME @{thm refl}), true), (dummy_function_names, ([], needs_random))))) + #> Alt_Compilations_Data.map (Symtab.insert (K false) (pred_name, alt_compilations)) + end + +fun functional_compilation fun_name mode compfuns T = + let + val (inpTs, outpTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE)) + mode (binder_types T) + val bs = map (pair "x") inpTs + val bounds = map Bound (rev (0 upto (length bs) - 1)) + val f = Const (fun_name, inpTs ---> HOLogic.mk_tupleT outpTs) + in list_abs (bs, mk_single compfuns (list_comb (f, bounds))) end + +fun register_alternative_function pred_name mode fun_name = + Alt_Compilations_Data.map (Symtab.insert_list (eq_pair eq_mode (K false)) + (pred_name, (mode, functional_compilation fun_name mode))) + +fun force_modes_and_functions pred_name fun_names = + force_modes_and_compilations pred_name + (map (fn (mode, (fun_name, random)) => (mode, (functional_compilation fun_name mode, random))) + fun_names) + +end; \ No newline at end of file diff -r b6acda4d1c29 -r ea46574ca815 src/HOL/Tools/Predicate_Compile/mode_inference.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Predicate_Compile/mode_inference.ML Thu Oct 21 19:13:09 2010 +0200 @@ -0,0 +1,643 @@ +(* Title: HOL/Tools/Predicate_Compile/mode_inference.ML + Author: Lukas Bulwahn, TU Muenchen + +Mode inference for the predicate compiler + +*) + +signature MODE_INFERENCE = +sig + type mode = Predicate_Compile_Aux.mode + + (* options *) + type mode_analysis_options = + {use_generators : bool, + reorder_premises : bool, + infer_pos_and_neg_modes : bool} + + (* mode operation *) + val all_input_of : typ -> mode + (* mode derivation and operations *) + datatype mode_derivation = Mode_App of mode_derivation * mode_derivation | Context of mode + | Mode_Pair of mode_derivation * mode_derivation | Term of mode + + val head_mode_of : mode_derivation -> mode + val param_derivations_of : mode_derivation -> mode_derivation list + val collect_context_modes : mode_derivation -> mode list + + type moded_clause = term list * (Predicate_Compile_Aux.indprem * mode_derivation) list + type 'a pred_mode_table = (string * ((bool * mode) * 'a) list) list + + (* mode inference operations *) + val all_derivations_of : + Proof.context -> (string * mode list) list -> string list -> term + -> (mode_derivation * string list) list + (* TODO: move all_modes creation to infer_modes *) + val infer_modes : + mode_analysis_options -> Predicate_Compile_Aux.options -> + (string -> mode list) * (string -> mode list) + * (string -> mode -> bool) -> Proof.context -> (string * typ) list -> + (string * mode list) list -> + string list -> (string * (Term.term list * Predicate_Compile_Aux.indprem list) list) list -> + ((moded_clause list pred_mode_table * (string * mode list) list) * string list) + + (* mode and term operations -- to be moved to Predicate_Compile_Aux *) + val collect_non_invertible_subterms : + Proof.context -> term -> string list * term list -> (term * (string list * term list)) + val is_all_input : mode -> bool + val term_vs : term -> string list + val terms_vs : term list -> string list + +end; + +structure Mode_Inference : MODE_INFERENCE = +struct + +open Predicate_Compile_Aux; +open Core_Data; + +(* derivation trees for modes of premises *) + +datatype mode_derivation = Mode_App of mode_derivation * mode_derivation | Context of mode + | Mode_Pair of mode_derivation * mode_derivation | Term of mode + +fun string_of_derivation (Mode_App (m1, m2)) = + "App (" ^ string_of_derivation m1 ^ ", " ^ string_of_derivation m2 ^ ")" + | string_of_derivation (Mode_Pair (m1, m2)) = + "Pair (" ^ string_of_derivation m1 ^ ", " ^ string_of_derivation m2 ^ ")" + | string_of_derivation (Term m) = "Term (" ^ string_of_mode m ^ ")" + | string_of_derivation (Context m) = "Context (" ^ string_of_mode m ^ ")" + +fun strip_mode_derivation deriv = + let + fun strip (Mode_App (deriv1, deriv2)) ds = strip deriv1 (deriv2 :: ds) + | strip deriv ds = (deriv, ds) + in + strip deriv [] + end + +fun mode_of (Context m) = m + | mode_of (Term m) = m + | mode_of (Mode_App (d1, d2)) = + (case mode_of d1 of Fun (m, m') => + (if eq_mode (m, mode_of d2) then m' else raise Fail "mode_of: derivation has mismatching modes") + | _ => raise Fail "mode_of: derivation has a non-functional mode") + | mode_of (Mode_Pair (d1, d2)) = + Pair (mode_of d1, mode_of d2) + +fun head_mode_of deriv = mode_of (fst (strip_mode_derivation deriv)) + +fun param_derivations_of deriv = + let + val (_, argument_derivs) = strip_mode_derivation deriv + fun param_derivation (Mode_Pair (m1, m2)) = + param_derivation m1 @ param_derivation m2 + | param_derivation (Term _) = [] + | param_derivation m = [m] + in + maps param_derivation argument_derivs + end + +fun collect_context_modes (Mode_App (m1, m2)) = + collect_context_modes m1 @ collect_context_modes m2 + | collect_context_modes (Mode_Pair (m1, m2)) = + collect_context_modes m1 @ collect_context_modes m2 + | collect_context_modes (Context m) = [m] + | collect_context_modes (Term _) = [] + +type moded_clause = term list * (Predicate_Compile_Aux.indprem * mode_derivation) list +type 'a pred_mode_table = (string * ((bool * mode) * 'a) list) list + + +(** string_of functions **) + +fun string_of_prem ctxt (Prem t) = + (Syntax.string_of_term ctxt t) ^ "(premise)" + | string_of_prem ctxt (Negprem t) = + (Syntax.string_of_term ctxt (HOLogic.mk_not t)) ^ "(negative premise)" + | string_of_prem ctxt (Sidecond t) = + (Syntax.string_of_term ctxt t) ^ "(sidecondition)" + | string_of_prem ctxt _ = raise Fail "string_of_prem: unexpected input" + +fun string_of_clause ctxt pred (ts, prems) = + (space_implode " --> " + (map (string_of_prem ctxt) prems)) ^ " --> " ^ pred ^ " " + ^ (space_implode " " (map (Syntax.string_of_term ctxt) ts)) + +type mode_analysis_options = + {use_generators : bool, + reorder_premises : bool, + infer_pos_and_neg_modes : bool} + +fun is_constrt thy = + let + val cnstrs = flat (maps + (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd) + (Symtab.dest (Datatype.get_all thy))); + fun check t = (case strip_comb t of + (Free _, []) => true + | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of + (SOME (i, Tname), Type (Tname', _)) => + length ts = i andalso Tname = Tname' andalso forall check ts + | _ => false) + | _ => false) + in check end; + +(*** check if a type is an equality type (i.e. doesn't contain fun) + FIXME this is only an approximation ***) +fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts + | is_eqT _ = true; + +fun term_vs tm = fold_aterms (fn Free (x, T) => cons x | _ => I) tm []; +val terms_vs = distinct (op =) o maps term_vs; + +(** collect all Frees in a term (with duplicates!) **) +fun term_vTs tm = + fold_aterms (fn Free xT => cons xT | _ => I) tm []; + +fun subsets i j = + if i <= j then + let + fun merge xs [] = xs + | merge [] ys = ys + | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys) + else y::merge (x::xs) ys; + val is = subsets (i+1) j + in merge (map (fn ks => i::ks) is) is end + else [[]]; + +fun print_failed_mode options thy modes p (pol, m) rs is = + if show_mode_inference options then + let + val _ = tracing ("Clauses " ^ commas (map (fn i => string_of_int (i + 1)) is) ^ " of " ^ + p ^ " violates mode " ^ string_of_mode m) + in () end + else () + +fun error_of p (pol, m) is = + " Clauses " ^ commas (map (fn i => string_of_int (i + 1)) is) ^ " of " ^ + p ^ " violates mode " ^ string_of_mode m + +fun is_all_input mode = + let + fun is_all_input' (Fun _) = true + | is_all_input' (Pair (m1, m2)) = is_all_input' m1 andalso is_all_input' m2 + | is_all_input' Input = true + | is_all_input' Output = false + in + forall is_all_input' (strip_fun_mode mode) + end + +fun all_input_of T = + let + val (Ts, U) = strip_type T + fun input_of (Type (@{type_name Product_Type.prod}, [T1, T2])) = Pair (input_of T1, input_of T2) + | input_of _ = Input + in + if U = HOLogic.boolT then + fold_rev (curry Fun) (map input_of Ts) Bool + else + raise Fail "all_input_of: not a predicate" + end + +fun find_least ord xs = + let + fun find' x y = (case y of NONE => SOME x | SOME y' => if ord (x, y') = LESS then SOME x else y) + in + fold find' xs NONE + end + +fun term_vs tm = fold_aterms (fn Free (x, T) => cons x | _ => I) tm []; +val terms_vs = distinct (op =) o maps term_vs; + +fun input_mode T = + let + val (Ts, U) = strip_type T + in + fold_rev (curry Fun) (map (K Input) Ts) Input + end + +fun output_mode T = + let + val (Ts, U) = strip_type T + in + fold_rev (curry Fun) (map (K Output) Ts) Output + end + +fun is_invertible_function ctxt (Const (f, _)) = is_constr ctxt f + | is_invertible_function ctxt _ = false + +fun non_invertible_subterms ctxt (t as Free _) = [] + | non_invertible_subterms ctxt t = + let + val (f, args) = strip_comb t + in + if is_invertible_function ctxt f then + maps (non_invertible_subterms ctxt) args + else + [t] + end + +fun collect_non_invertible_subterms ctxt (f as Free _) (names, eqs) = (f, (names, eqs)) + | collect_non_invertible_subterms ctxt t (names, eqs) = + case (strip_comb t) of (f, args) => + if is_invertible_function ctxt f then + let + val (args', (names', eqs')) = + fold_map (collect_non_invertible_subterms ctxt) args (names, eqs) + in + (list_comb (f, args'), (names', eqs')) + end + else + let + val s = Name.variant names "x" + val v = Free (s, fastype_of t) + in + (v, (s :: names, HOLogic.mk_eq (v, t) :: eqs)) + end +(* + if is_constrt thy t then (t, (names, eqs)) else + let + val s = Name.variant names "x" + val v = Free (s, fastype_of t) + in (v, (s::names, HOLogic.mk_eq (v, t)::eqs)) end; +*) + +fun is_possible_output ctxt vs t = + forall + (fn t => is_eqT (fastype_of t) andalso forall (member (op =) vs) (term_vs t)) + (non_invertible_subterms ctxt t) + andalso + (forall (is_eqT o snd) + (inter (fn ((f', _), f) => f = f') vs (Term.add_frees t []))) + +fun vars_of_destructable_term ctxt (Free (x, _)) = [x] + | vars_of_destructable_term ctxt t = + let + val (f, args) = strip_comb t + in + if is_invertible_function ctxt f then + maps (vars_of_destructable_term ctxt) args + else + [] + end + +fun is_constructable vs t = forall (member (op =) vs) (term_vs t) + +fun missing_vars vs t = subtract (op =) vs (term_vs t) + +fun output_terms (Const (@{const_name Pair}, _) $ t1 $ t2, Mode_Pair (d1, d2)) = + output_terms (t1, d1) @ output_terms (t2, d2) + | output_terms (t1 $ t2, Mode_App (d1, d2)) = + output_terms (t1, d1) @ output_terms (t2, d2) + | output_terms (t, Term Output) = [t] + | output_terms _ = [] + +fun lookup_mode modes (Const (s, T)) = + (case (AList.lookup (op =) modes s) of + SOME ms => SOME (map (fn m => (Context m, [])) ms) + | NONE => NONE) + | lookup_mode modes (Free (x, _)) = + (case (AList.lookup (op =) modes x) of + SOME ms => SOME (map (fn m => (Context m , [])) ms) + | NONE => NONE) + +fun derivations_of (ctxt : Proof.context) modes vs (Const (@{const_name Pair}, _) $ t1 $ t2) (Pair (m1, m2)) = + map_product + (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2)) + (derivations_of ctxt modes vs t1 m1) (derivations_of ctxt modes vs t2 m2) + | derivations_of ctxt modes vs t (m as Fun _) = + (*let + val (p, args) = strip_comb t + in + (case lookup_mode modes p of + SOME ms => map_filter (fn (Context m, []) => let + val ms = strip_fun_mode m + val (argms, restms) = chop (length args) ms + val m' = fold_rev (curry Fun) restms Bool + in + if forall (fn m => eq_mode (Input, m)) argms andalso eq_mode (m', mode) then + SOME (fold (curry Mode_App) (map Term argms) (Context m), missing_vars vs t) + else NONE + end) ms + | NONE => (if is_all_input mode then [(Context mode, [])] else [])) + end*) + (case try (all_derivations_of ctxt modes vs) t of + SOME derivs => + filter (fn (d, mvars) => eq_mode (mode_of d, m) andalso null (output_terms (t, d))) derivs + | NONE => (if is_all_input m then [(Context m, [])] else [])) + | derivations_of ctxt modes vs t m = + if eq_mode (m, Input) then + [(Term Input, missing_vars vs t)] + else if eq_mode (m, Output) then + (if is_possible_output ctxt vs t then [(Term Output, [])] else []) + else [] +and all_derivations_of ctxt modes vs (Const (@{const_name Pair}, _) $ t1 $ t2) = + let + val derivs1 = all_derivations_of ctxt modes vs t1 + val derivs2 = all_derivations_of ctxt modes vs t2 + in + map_product + (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2)) + derivs1 derivs2 + end + | all_derivations_of ctxt modes vs (t1 $ t2) = + let + val derivs1 = all_derivations_of ctxt modes vs t1 + in + maps (fn (d1, mvars1) => + case mode_of d1 of + Fun (m', _) => map (fn (d2, mvars2) => + (Mode_App (d1, d2), union (op =) mvars1 mvars2)) (derivations_of ctxt modes vs t2 m') + | _ => raise Fail "all_derivations_of: derivation has an unexpected non-functional mode") derivs1 + end + | all_derivations_of _ modes vs (Const (s, T)) = the (lookup_mode modes (Const (s, T))) + | all_derivations_of _ modes vs (Free (x, T)) = the (lookup_mode modes (Free (x, T))) + | all_derivations_of _ modes vs _ = raise Fail "all_derivations_of: unexpected term" + +fun rev_option_ord ord (NONE, NONE) = EQUAL + | rev_option_ord ord (NONE, SOME _) = GREATER + | rev_option_ord ord (SOME _, NONE) = LESS + | rev_option_ord ord (SOME x, SOME y) = ord (x, y) + +fun random_mode_in_deriv modes t deriv = + case try dest_Const (fst (strip_comb t)) of + SOME (s, _) => + (case AList.lookup (op =) modes s of + SOME ms => + (case AList.lookup (op =) (map (fn ((p, m), r) => (m, r)) ms) (head_mode_of deriv) of + SOME r => r + | NONE => false) + | NONE => false) + | NONE => false + +fun number_of_output_positions mode = + let + val args = strip_fun_mode mode + fun contains_output (Fun _) = false + | contains_output Input = false + | contains_output Output = true + | contains_output (Pair (m1, m2)) = contains_output m1 orelse contains_output m2 + in + length (filter contains_output args) + end + +fun lex_ord ord1 ord2 (x, x') = + case ord1 (x, x') of + EQUAL => ord2 (x, x') + | ord => ord + +fun lexl_ord [] (x, x') = EQUAL + | lexl_ord (ord :: ords') (x, x') = + case ord (x, x') of + EQUAL => lexl_ord ords' (x, x') + | ord => ord + +fun deriv_ord' ctxt pol pred modes t1 t2 ((deriv1, mvars1), (deriv2, mvars2)) = + let + (* prefer functional modes if it is a function *) + fun fun_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = + let + fun is_functional t mode = + case try (fst o dest_Const o fst o strip_comb) t of + NONE => false + | SOME c => is_some (alternative_compilation_of ctxt c mode) + in + case (is_functional t1 (head_mode_of deriv1), is_functional t2 (head_mode_of deriv2)) of + (true, true) => EQUAL + | (true, false) => LESS + | (false, true) => GREATER + | (false, false) => EQUAL + end + (* prefer modes without requirement for generating random values *) + fun mvars_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = + int_ord (length mvars1, length mvars2) + (* prefer non-random modes *) + fun random_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = + int_ord (if random_mode_in_deriv modes t1 deriv1 then 1 else 0, + if random_mode_in_deriv modes t2 deriv2 then 1 else 0) + (* prefer modes with more input and less output *) + fun output_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = + int_ord (number_of_output_positions (head_mode_of deriv1), + number_of_output_positions (head_mode_of deriv2)) + (* prefer recursive calls *) + fun is_rec_premise t = + case fst (strip_comb t) of Const (c, T) => c = pred | _ => false + fun recursive_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = + int_ord (if is_rec_premise t1 then 0 else 1, + if is_rec_premise t2 then 0 else 1) + val ord = lexl_ord [mvars_ord, fun_ord, random_mode_ord, output_mode_ord, recursive_ord] + in + ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) + end + +fun deriv_ord ctxt pol pred modes t = deriv_ord' ctxt pol pred modes t t + +fun premise_ord thy pol pred modes ((prem1, a1), (prem2, a2)) = + rev_option_ord (deriv_ord' thy pol pred modes (dest_indprem prem1) (dest_indprem prem2)) (a1, a2) + +fun print_mode_list modes = + tracing ("modes: " ^ (commas (map (fn (s, ms) => s ^ ": " ^ + commas (map (fn (m, r) => string_of_mode m ^ (if r then " random " else " not ")) ms)) modes))) + +fun select_mode_prem (mode_analysis_options : mode_analysis_options) (ctxt : Proof.context) pred + pol (modes, (pos_modes, neg_modes)) vs ps = + let + fun choose_mode_of_prem (Prem t) = + find_least (deriv_ord ctxt pol pred modes t) (all_derivations_of ctxt pos_modes vs t) + | choose_mode_of_prem (Sidecond t) = SOME (Context Bool, missing_vars vs t) + | choose_mode_of_prem (Negprem t) = find_least (deriv_ord ctxt (not pol) pred modes t) + (filter (fn (d, missing_vars) => is_all_input (head_mode_of d)) + (all_derivations_of ctxt neg_modes vs t)) + | choose_mode_of_prem p = raise Fail ("choose_mode_of_prem: unexpected premise " ^ string_of_prem ctxt p) + in + if #reorder_premises mode_analysis_options then + find_least (premise_ord ctxt pol pred modes) (ps ~~ map choose_mode_of_prem ps) + else + SOME (hd ps, choose_mode_of_prem (hd ps)) + end + +fun check_mode_clause' (mode_analysis_options : mode_analysis_options) ctxt pred param_vs (modes : + (string * ((bool * mode) * bool) list) list) ((pol, mode) : bool * mode) (ts, ps) = + let + val vTs = distinct (op =) (fold Term.add_frees (map dest_indprem ps) (fold Term.add_frees ts [])) + val modes' = modes @ (param_vs ~~ map (fn x => [((true, x), false), ((false, x), false)]) (ho_arg_modes_of mode)) + fun retrieve_modes_of_pol pol = map (fn (s, ms) => + (s, map_filter (fn ((p, m), r) => if p = pol then SOME m else NONE | _ => NONE) ms)) + val (pos_modes', neg_modes') = + if #infer_pos_and_neg_modes mode_analysis_options then + (retrieve_modes_of_pol pol modes', retrieve_modes_of_pol (not pol) modes') + else + let + val modes = map (fn (s, ms) => (s, map (fn ((p, m), r) => m) ms)) modes' + in (modes, modes) end + val (in_ts, out_ts) = split_mode mode ts + val in_vs = union (op =) param_vs (maps (vars_of_destructable_term ctxt) in_ts) + val out_vs = terms_vs out_ts + fun known_vs_after p vs = (case p of + Prem t => union (op =) vs (term_vs t) + | Sidecond t => union (op =) vs (term_vs t) + | Negprem t => union (op =) vs (term_vs t) + | _ => raise Fail "unexpected premise") + fun check_mode_prems acc_ps rnd vs [] = SOME (acc_ps, vs, rnd) + | check_mode_prems acc_ps rnd vs ps = + (case + (select_mode_prem mode_analysis_options ctxt pred pol (modes', (pos_modes', neg_modes')) vs ps) of + SOME (p, SOME (deriv, [])) => check_mode_prems ((p, deriv) :: acc_ps) rnd + (known_vs_after p vs) (filter_out (equal p) ps) + | SOME (p, SOME (deriv, missing_vars)) => + if #use_generators mode_analysis_options andalso pol then + check_mode_prems ((p, deriv) :: (map + (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output)) + (distinct (op =) missing_vars)) + @ acc_ps) true (known_vs_after p vs) (filter_out (equal p) ps) + else NONE + | SOME (p, NONE) => NONE + | NONE => NONE) + in + case check_mode_prems [] false in_vs ps of + NONE => NONE + | SOME (acc_ps, vs, rnd) => + if forall (is_constructable vs) (in_ts @ out_ts) then + SOME (ts, rev acc_ps, rnd) + else + if #use_generators mode_analysis_options andalso pol then + let + val generators = map + (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output)) + (subtract (op =) vs (terms_vs (in_ts @ out_ts))) + in + SOME (ts, rev (generators @ acc_ps), true) + end + else + NONE + end + +datatype result = Success of bool | Error of string + +fun check_modes_pred' mode_analysis_options options thy param_vs clauses modes (p, (ms : ((bool * mode) * bool) list)) = + let + fun split xs = + let + fun split' [] (ys, zs) = (rev ys, rev zs) + | split' ((m, Error z) :: xs) (ys, zs) = split' xs (ys, z :: zs) + | split' (((m : bool * mode), Success rnd) :: xs) (ys, zs) = split' xs ((m, rnd) :: ys, zs) + in + split' xs ([], []) + end + val rs = these (AList.lookup (op =) clauses p) + fun check_mode m = + let + val res = Output.cond_timeit false "work part of check_mode for one mode" (fn _ => + map (check_mode_clause' mode_analysis_options thy p param_vs modes m) rs) + in + Output.cond_timeit false "aux part of check_mode for one mode" (fn _ => + case find_indices is_none res of + [] => Success (exists (fn SOME (_, _, true) => true | _ => false) res) + | is => (print_failed_mode options thy modes p m rs is; Error (error_of p m is))) + end + val _ = if show_mode_inference options then + tracing ("checking " ^ string_of_int (length ms) ^ " modes ...") + else () + val res = Output.cond_timeit false "check_mode" (fn _ => map (fn (m, _) => (m, check_mode m)) ms) + val (ms', errors) = split res + in + ((p, (ms' : ((bool * mode) * bool) list)), errors) + end; + +fun get_modes_pred' mode_analysis_options thy param_vs clauses modes (p, ms) = + let + val rs = these (AList.lookup (op =) clauses p) + in + (p, map (fn (m, rnd) => + (m, map + ((fn (ts, ps, rnd) => (ts, ps)) o the o + check_mode_clause' mode_analysis_options thy p param_vs modes m) rs)) ms) + end; + +fun fixp f (x : (string * ((bool * mode) * bool) list) list) = + let val y = f x + in if x = y then x else fixp f y end; + +fun fixp_with_state f (x : (string * ((bool * mode) * bool) list) list, state) = + let + val (y, state') = f (x, state) + in + if x = y then (y, state') else fixp_with_state f (y, state') + end + +fun string_of_ext_mode ((pol, mode), rnd) = + string_of_mode mode ^ "(" ^ (if pol then "pos" else "neg") ^ ", " + ^ (if rnd then "rnd" else "nornd") ^ ")" + +fun print_extra_modes options modes = + if show_mode_inference options then + tracing ("Modes of inferred predicates: " ^ + cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map string_of_ext_mode ms)) modes)) + else () + +fun infer_modes mode_analysis_options options (lookup_mode, lookup_neg_mode, needs_random) ctxt + preds all_modes param_vs clauses = + let + fun appair f (x1, x2) (y1, y2) = (f x1 y1, f x2 y2) + fun add_needs_random s (false, m) = ((false, m), false) + | add_needs_random s (true, m) = ((true, m), needs_random s m) + fun add_polarity_and_random_bit s b ms = map (fn m => add_needs_random s (b, m)) ms + val prednames = map fst preds + (* extramodes contains all modes of all constants, should we only use the necessary ones + - what is the impact on performance? *) + fun predname_of (Prem t) = + (case try dest_Const (fst (strip_comb t)) of SOME (c, _) => insert (op =) c | NONE => I) + | predname_of (Negprem t) = + (case try dest_Const (fst (strip_comb t)) of SOME (c, _) => insert (op =) c | NONE => I) + | predname_of _ = I + val relevant_prednames = fold (fn (_, clauses') => + fold (fn (_, ps) => fold Term.add_const_names (map dest_indprem ps)) clauses') clauses [] + |> filter_out (fn name => member (op =) prednames name) + val extra_modes = + if #infer_pos_and_neg_modes mode_analysis_options then + let + val pos_extra_modes = + map_filter (fn name => Option.map (pair name) (try lookup_mode name)) + relevant_prednames + val neg_extra_modes = + map_filter (fn name => Option.map (pair name) (try lookup_neg_mode name)) + relevant_prednames + in + map (fn (s, ms) => (s, (add_polarity_and_random_bit s true ms) + @ add_polarity_and_random_bit s false (the (AList.lookup (op =) neg_extra_modes s)))) + pos_extra_modes + end + else + map (fn (s, ms) => (s, (add_polarity_and_random_bit s true ms))) + (map_filter (fn name => Option.map (pair name) (try lookup_mode name)) + relevant_prednames) + val _ = print_extra_modes options extra_modes + val start_modes = + if #infer_pos_and_neg_modes mode_analysis_options then + map (fn (s, ms) => (s, map (fn m => ((true, m), false)) ms @ + (map (fn m => ((false, m), false)) ms))) all_modes + else + map (fn (s, ms) => (s, map (fn m => ((true, m), false)) ms)) all_modes + fun iteration modes = map + (check_modes_pred' mode_analysis_options options ctxt param_vs clauses + (modes @ extra_modes)) modes + val ((modes : (string * ((bool * mode) * bool) list) list), errors) = + Output.cond_timeit false "Fixpount computation of mode analysis" (fn () => + if show_invalid_clauses options then + fixp_with_state (fn (modes, errors) => + let + val (modes', new_errors) = split_list (iteration modes) + in (modes', errors @ flat new_errors) end) (start_modes, []) + else + (fixp (fn modes => map fst (iteration modes)) start_modes, [])) + val moded_clauses = map (get_modes_pred' mode_analysis_options ctxt param_vs clauses + (modes @ extra_modes)) modes + val need_random = fold (fn (s, ms) => if member (op =) (map fst preds) s then + cons (s, (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms)) else I) + modes [] + in + ((moded_clauses, need_random), errors) + end; + +end; \ No newline at end of file diff -r b6acda4d1c29 -r ea46574ca815 src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Thu Oct 21 19:13:09 2010 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Thu Oct 21 19:13:09 2010 +0200 @@ -150,6 +150,7 @@ val imp_prems_conv : conv -> conv (* simple transformations *) val split_conjuncts_in_assms : Proof.context -> thm -> thm + val dest_conjunct_prem : thm -> thm list val expand_tuples : theory -> thm -> thm val case_betapply : theory -> term -> term val eta_contract_ho_arguments : theory -> thm -> thm @@ -157,7 +158,7 @@ val remove_pointless_clauses : thm -> thm list val peephole_optimisation : theory -> thm -> thm option val define_quickcheck_predicate : - term -> theory -> (((string * typ) * (string * typ) list) * thm) * theory + term -> theory -> (((string * typ) * (string * typ) list) * thm) * theory end; structure Predicate_Compile_Aux : PREDICATE_COMPILE_AUX = @@ -854,7 +855,14 @@ in singleton (Variable.export ctxt' ctxt) (split_conjs 1 (Thm.nprems_of fixed_th) fixed_th) end - + +fun dest_conjunct_prem th = + case HOLogic.dest_Trueprop (prop_of th) of + (Const (@{const_name HOL.conj}, _) $ t $ t') => + dest_conjunct_prem (th RS @{thm conjunct1}) + @ dest_conjunct_prem (th RS @{thm conjunct2}) + | _ => [th] + fun expand_tuples thy intro = let val ctxt = ProofContext.init_global thy diff -r b6acda4d1c29 -r ea46574ca815 src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Oct 21 19:13:09 2010 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Oct 21 19:13:09 2010 +0200 @@ -78,35 +78,16 @@ type moded_clause = term list * (Predicate_Compile_Aux.indprem * mode_derivation) list type 'a pred_mode_table = (string * ((bool * mode) * 'a) list) list - val infer_modes : - mode_analysis_options -> options -> - (string -> Predicate_Compile_Aux.mode list) * (string -> Predicate_Compile_Aux.mode list) - * (string -> Predicate_Compile_Aux.mode -> bool) -> Proof.context -> (string * typ) list -> - (string * mode list) list -> - string list -> (string * (Term.term list * Predicate_Compile_Aux.indprem list) list) list -> - ((moded_clause list pred_mode_table * (string * mode list) list) * string list) + end; structure Predicate_Compile_Core : PREDICATE_COMPILE_CORE = struct open Predicate_Compile_Aux; - -(** auxiliary **) - -(* debug stuff *) - -fun print_tac options s = - if show_proof_trace options then Tactical.print_tac s else Seq.single; - -fun assert b = if not b then raise Fail "Assertion failed" else warning "Assertion holds" - -datatype assertion = Max_number_of_subgoals of int -fun assert_tac (Max_number_of_subgoals i) st = - if (nprems_of st <= i) then Seq.single st - else raise Fail ("assert_tac: Numbers of subgoals mismatch at goal state :" - ^ "\n" ^ Pretty.string_of (Pretty.chunks - (Goal_Display.pretty_goals_without_context st))); +open Core_Data; +open Mode_Inference; +open Predicate_Compile_Proof; (** fundamentals **) @@ -149,181 +130,12 @@ val strip_intro_concl = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of) -(* derivation trees for modes of premises *) - -datatype mode_derivation = Mode_App of mode_derivation * mode_derivation | Context of mode - | Mode_Pair of mode_derivation * mode_derivation | Term of mode - -fun string_of_derivation (Mode_App (m1, m2)) = - "App (" ^ string_of_derivation m1 ^ ", " ^ string_of_derivation m2 ^ ")" - | string_of_derivation (Mode_Pair (m1, m2)) = - "Pair (" ^ string_of_derivation m1 ^ ", " ^ string_of_derivation m2 ^ ")" - | string_of_derivation (Term m) = "Term (" ^ string_of_mode m ^ ")" - | string_of_derivation (Context m) = "Context (" ^ string_of_mode m ^ ")" - -fun strip_mode_derivation deriv = - let - fun strip (Mode_App (deriv1, deriv2)) ds = strip deriv1 (deriv2 :: ds) - | strip deriv ds = (deriv, ds) - in - strip deriv [] - end - -fun mode_of (Context m) = m - | mode_of (Term m) = m - | mode_of (Mode_App (d1, d2)) = - (case mode_of d1 of Fun (m, m') => - (if eq_mode (m, mode_of d2) then m' else raise Fail "mode_of: derivation has mismatching modes") - | _ => raise Fail "mode_of: derivation has a non-functional mode") - | mode_of (Mode_Pair (d1, d2)) = - Pair (mode_of d1, mode_of d2) - -fun head_mode_of deriv = mode_of (fst (strip_mode_derivation deriv)) - -fun param_derivations_of deriv = - let - val (_, argument_derivs) = strip_mode_derivation deriv - fun param_derivation (Mode_Pair (m1, m2)) = - param_derivation m1 @ param_derivation m2 - | param_derivation (Term _) = [] - | param_derivation m = [m] - in - maps param_derivation argument_derivs - end - -fun collect_context_modes (Mode_App (m1, m2)) = - collect_context_modes m1 @ collect_context_modes m2 - | collect_context_modes (Mode_Pair (m1, m2)) = - collect_context_modes m1 @ collect_context_modes m2 - | collect_context_modes (Context m) = [m] - | collect_context_modes (Term _) = [] - (* representation of inferred clauses with modes *) type moded_clause = term list * (indprem * mode_derivation) list type 'a pred_mode_table = (string * ((bool * mode) * 'a) list) list -(* book-keeping *) - -datatype predfun_data = PredfunData of { - definition : thm, - intro : thm, - elim : thm, - neg_intro : thm option -}; - -fun rep_predfun_data (PredfunData data) = data; - -fun mk_predfun_data (definition, ((intro, elim), neg_intro)) = - PredfunData {definition = definition, intro = intro, elim = elim, neg_intro = neg_intro} - -datatype pred_data = PredData of { - intros : (string option * thm) list, - elim : thm option, - preprocessed : bool, - function_names : (compilation * (mode * string) list) list, - predfun_data : (mode * predfun_data) list, - needs_random : mode list -}; - -fun rep_pred_data (PredData data) = data; - -fun mk_pred_data (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random))) = - PredData {intros = intros, elim = elim, preprocessed = preprocessed, - function_names = function_names, predfun_data = predfun_data, needs_random = needs_random} - -fun map_pred_data f (PredData {intros, elim, preprocessed, function_names, predfun_data, needs_random}) = - mk_pred_data (f (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random)))) - -fun eq_option eq (NONE, NONE) = true - | eq_option eq (SOME x, SOME y) = eq (x, y) - | eq_option eq _ = false - -fun eq_pred_data (PredData d1, PredData d2) = - eq_list (eq_pair (op =) Thm.eq_thm) (#intros d1, #intros d2) andalso - eq_option Thm.eq_thm (#elim d1, #elim d2) - -structure PredData = Theory_Data -( - type T = pred_data Graph.T; - val empty = Graph.empty; - val extend = I; - val merge = Graph.merge eq_pred_data; -); - -(* queries *) - -fun lookup_pred_data ctxt name = - Option.map rep_pred_data (try (Graph.get_node (PredData.get (ProofContext.theory_of ctxt))) name) - -fun the_pred_data ctxt name = case lookup_pred_data ctxt name - of NONE => error ("No such predicate " ^ quote name) - | SOME data => data; - -val is_registered = is_some oo lookup_pred_data - -val all_preds_of = Graph.keys o PredData.get o ProofContext.theory_of - -val intros_of = map snd o #intros oo the_pred_data - -val names_of = map fst o #intros oo the_pred_data - -fun the_elim_of ctxt name = case #elim (the_pred_data ctxt name) - of NONE => error ("No elimination rule for predicate " ^ quote name) - | SOME thm => thm - -val has_elim = is_some o #elim oo the_pred_data - -fun function_names_of compilation ctxt name = - case AList.lookup (op =) (#function_names (the_pred_data ctxt name)) compilation of - NONE => error ("No " ^ string_of_compilation compilation - ^ " functions defined for predicate " ^ quote name) - | SOME fun_names => fun_names - -fun function_name_of compilation ctxt name mode = - case AList.lookup eq_mode - (function_names_of compilation ctxt name) mode of - NONE => error ("No " ^ string_of_compilation compilation - ^ " function defined for mode " ^ string_of_mode mode ^ " of predicate " ^ quote name) - | SOME function_name => function_name - -fun modes_of compilation ctxt name = map fst (function_names_of compilation ctxt name) - -fun all_modes_of compilation ctxt = - map_filter (fn name => Option.map (pair name) (try (modes_of compilation ctxt) name)) - (all_preds_of ctxt) - -val all_random_modes_of = all_modes_of Random - -fun defined_functions compilation ctxt name = case lookup_pred_data ctxt name of - NONE => false - | SOME data => AList.defined (op =) (#function_names data) compilation - -fun needs_random ctxt s m = - member (op =) (#needs_random (the_pred_data ctxt s)) m - -fun lookup_predfun_data ctxt name mode = - Option.map rep_predfun_data - (AList.lookup (op =) (#predfun_data (the_pred_data ctxt name)) mode) - -fun the_predfun_data ctxt name mode = - case lookup_predfun_data ctxt name mode of - NONE => error ("No function defined for mode " ^ string_of_mode mode ^ - " of predicate " ^ name) - | SOME data => data; - -val predfun_definition_of = #definition ooo the_predfun_data - -val predfun_intro_of = #intro ooo the_predfun_data - -val predfun_elim_of = #elim ooo the_predfun_data - -val predfun_neg_intro_of = #neg_intro ooo the_predfun_data - -val intros_graph_of = - Graph.map (K (map snd o #intros o rep_pred_data)) o PredData.get o ProofContext.theory_of - (* diagnostic display functions *) fun print_modes options modes = @@ -342,19 +154,6 @@ val _ = tracing (cat_lines (map print_pred pred_mode_table)) in () end; -fun string_of_prem ctxt (Prem t) = - (Syntax.string_of_term ctxt t) ^ "(premise)" - | string_of_prem ctxt (Negprem t) = - (Syntax.string_of_term ctxt (HOLogic.mk_not t)) ^ "(negative premise)" - | string_of_prem ctxt (Sidecond t) = - (Syntax.string_of_term ctxt t) ^ "(sidecondition)" - | string_of_prem ctxt _ = raise Fail "string_of_prem: unexpected input" - -fun string_of_clause ctxt pred (ts, prems) = - (space_implode " --> " - (map (string_of_prem ctxt) prems)) ^ " --> " ^ pred ^ " " - ^ (space_implode " " (map (Syntax.string_of_term ctxt) ts)) - fun print_compiled_terms options ctxt = if show_compilation options then print_pred_mode_table (fn _ => fn _ => Syntax.string_of_term ctxt) @@ -585,13 +384,6 @@ val cases = map mk_case intros in Logic.list_implies (assm :: cases, prop) end; -fun dest_conjunct_prem th = - case HOLogic.dest_Trueprop (prop_of th) of - (Const (@{const_name HOL.conj}, _) $ t $ t') => - dest_conjunct_prem (th RS @{thm conjunct1}) - @ dest_conjunct_prem (th RS @{thm conjunct2}) - | _ => [th] - fun prove_casesrule ctxt (pred, (pre_cases_rule, nparams)) cases_rule = let val thy = ProofContext.theory_of ctxt @@ -755,58 +547,6 @@ PredData.map (Graph.map_node name (map_pred_data set)) end -(* registration of alternative function names *) - -structure Alt_Compilations_Data = Theory_Data -( - type T = (mode * (compilation_funs -> typ -> term)) list Symtab.table; - val empty = Symtab.empty; - val extend = I; - fun merge data : T = Symtab.merge (K true) data; -); - -fun alternative_compilation_of_global thy pred_name mode = - AList.lookup eq_mode (Symtab.lookup_list (Alt_Compilations_Data.get thy) pred_name) mode - -fun alternative_compilation_of ctxt pred_name mode = - AList.lookup eq_mode - (Symtab.lookup_list (Alt_Compilations_Data.get (ProofContext.theory_of ctxt)) pred_name) mode - -fun force_modes_and_compilations pred_name compilations = - let - (* thm refl is a dummy thm *) - val modes = map fst compilations - val (needs_random, non_random_modes) = pairself (map fst) - (List.partition (fn (m, (fun_name, random)) => random) compilations) - val non_random_dummys = map (rpair "dummy") non_random_modes - val all_dummys = map (rpair "dummy") modes - val dummy_function_names = map (rpair all_dummys) Predicate_Compile_Aux.random_compilations - @ map (rpair non_random_dummys) Predicate_Compile_Aux.non_random_compilations - val alt_compilations = map (apsnd fst) compilations - in - PredData.map (Graph.new_node - (pred_name, mk_pred_data ((([], SOME @{thm refl}), true), (dummy_function_names, ([], needs_random))))) - #> Alt_Compilations_Data.map (Symtab.insert (K false) (pred_name, alt_compilations)) - end - -fun functional_compilation fun_name mode compfuns T = - let - val (inpTs, outpTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE)) - mode (binder_types T) - val bs = map (pair "x") inpTs - val bounds = map Bound (rev (0 upto (length bs) - 1)) - val f = Const (fun_name, inpTs ---> HOLogic.mk_tupleT outpTs) - in list_abs (bs, mk_single compfuns (list_comb (f, bounds))) end - -fun register_alternative_function pred_name mode fun_name = - Alt_Compilations_Data.map (Symtab.insert_list (eq_pair eq_mode (K false)) - (pred_name, (mode, functional_compilation fun_name mode))) - -fun force_modes_and_functions pred_name fun_names = - force_modes_and_compilations pred_name - (map (fn (mode, (fun_name, random)) => (mode, (functional_compilation fun_name mode, random))) - fun_names) - (* compilation modifiers *) structure Comp_Mod = @@ -1133,524 +873,6 @@ | Neg_Generator_DSeq => pos_generator_dseq_comp_modifiers | c => comp_modifiers) -(** mode analysis **) - -type mode_analysis_options = - {use_generators : bool, - reorder_premises : bool, - infer_pos_and_neg_modes : bool} - -fun is_constrt thy = - let - val cnstrs = flat (maps - (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd) - (Symtab.dest (Datatype.get_all thy))); - fun check t = (case strip_comb t of - (Free _, []) => true - | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of - (SOME (i, Tname), Type (Tname', _)) => - length ts = i andalso Tname = Tname' andalso forall check ts - | _ => false) - | _ => false) - in check end; - -(*** check if a type is an equality type (i.e. doesn't contain fun) - FIXME this is only an approximation ***) -fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts - | is_eqT _ = true; - -fun term_vs tm = fold_aterms (fn Free (x, T) => cons x | _ => I) tm []; -val terms_vs = distinct (op =) o maps term_vs; - -(** collect all Frees in a term (with duplicates!) **) -fun term_vTs tm = - fold_aterms (fn Free xT => cons xT | _ => I) tm []; - -fun subsets i j = - if i <= j then - let - fun merge xs [] = xs - | merge [] ys = ys - | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys) - else y::merge (x::xs) ys; - val is = subsets (i+1) j - in merge (map (fn ks => i::ks) is) is end - else [[]]; - -fun print_failed_mode options thy modes p (pol, m) rs is = - if show_mode_inference options then - let - val _ = tracing ("Clauses " ^ commas (map (fn i => string_of_int (i + 1)) is) ^ " of " ^ - p ^ " violates mode " ^ string_of_mode m) - in () end - else () - -fun error_of p (pol, m) is = - " Clauses " ^ commas (map (fn i => string_of_int (i + 1)) is) ^ " of " ^ - p ^ " violates mode " ^ string_of_mode m - -fun is_all_input mode = - let - fun is_all_input' (Fun _) = true - | is_all_input' (Pair (m1, m2)) = is_all_input' m1 andalso is_all_input' m2 - | is_all_input' Input = true - | is_all_input' Output = false - in - forall is_all_input' (strip_fun_mode mode) - end - -fun all_input_of T = - let - val (Ts, U) = strip_type T - fun input_of (Type (@{type_name Product_Type.prod}, [T1, T2])) = Pair (input_of T1, input_of T2) - | input_of _ = Input - in - if U = HOLogic.boolT then - fold_rev (curry Fun) (map input_of Ts) Bool - else - raise Fail "all_input_of: not a predicate" - end - -fun find_least ord xs = - let - fun find' x y = (case y of NONE => SOME x | SOME y' => if ord (x, y') = LESS then SOME x else y) - in - fold find' xs NONE - end - -fun term_vs tm = fold_aterms (fn Free (x, T) => cons x | _ => I) tm []; -val terms_vs = distinct (op =) o maps term_vs; - -fun input_mode T = - let - val (Ts, U) = strip_type T - in - fold_rev (curry Fun) (map (K Input) Ts) Input - end - -fun output_mode T = - let - val (Ts, U) = strip_type T - in - fold_rev (curry Fun) (map (K Output) Ts) Output - end - -fun is_invertible_function ctxt (Const (f, _)) = is_constr ctxt f - | is_invertible_function ctxt _ = false - -fun non_invertible_subterms ctxt (t as Free _) = [] - | non_invertible_subterms ctxt t = - let - val (f, args) = strip_comb t - in - if is_invertible_function ctxt f then - maps (non_invertible_subterms ctxt) args - else - [t] - end - -fun collect_non_invertible_subterms ctxt (f as Free _) (names, eqs) = (f, (names, eqs)) - | collect_non_invertible_subterms ctxt t (names, eqs) = - case (strip_comb t) of (f, args) => - if is_invertible_function ctxt f then - let - val (args', (names', eqs')) = - fold_map (collect_non_invertible_subterms ctxt) args (names, eqs) - in - (list_comb (f, args'), (names', eqs')) - end - else - let - val s = Name.variant names "x" - val v = Free (s, fastype_of t) - in - (v, (s :: names, HOLogic.mk_eq (v, t) :: eqs)) - end -(* - if is_constrt thy t then (t, (names, eqs)) else - let - val s = Name.variant names "x" - val v = Free (s, fastype_of t) - in (v, (s::names, HOLogic.mk_eq (v, t)::eqs)) end; -*) - -fun is_possible_output ctxt vs t = - forall - (fn t => is_eqT (fastype_of t) andalso forall (member (op =) vs) (term_vs t)) - (non_invertible_subterms ctxt t) - andalso - (forall (is_eqT o snd) - (inter (fn ((f', _), f) => f = f') vs (Term.add_frees t []))) - -fun vars_of_destructable_term ctxt (Free (x, _)) = [x] - | vars_of_destructable_term ctxt t = - let - val (f, args) = strip_comb t - in - if is_invertible_function ctxt f then - maps (vars_of_destructable_term ctxt) args - else - [] - end - -fun is_constructable vs t = forall (member (op =) vs) (term_vs t) - -fun missing_vars vs t = subtract (op =) vs (term_vs t) - -fun output_terms (Const (@{const_name Pair}, _) $ t1 $ t2, Mode_Pair (d1, d2)) = - output_terms (t1, d1) @ output_terms (t2, d2) - | output_terms (t1 $ t2, Mode_App (d1, d2)) = - output_terms (t1, d1) @ output_terms (t2, d2) - | output_terms (t, Term Output) = [t] - | output_terms _ = [] - -fun lookup_mode modes (Const (s, T)) = - (case (AList.lookup (op =) modes s) of - SOME ms => SOME (map (fn m => (Context m, [])) ms) - | NONE => NONE) - | lookup_mode modes (Free (x, _)) = - (case (AList.lookup (op =) modes x) of - SOME ms => SOME (map (fn m => (Context m , [])) ms) - | NONE => NONE) - -fun derivations_of (ctxt : Proof.context) modes vs (Const (@{const_name Pair}, _) $ t1 $ t2) (Pair (m1, m2)) = - map_product - (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2)) - (derivations_of ctxt modes vs t1 m1) (derivations_of ctxt modes vs t2 m2) - | derivations_of ctxt modes vs t (m as Fun _) = - (*let - val (p, args) = strip_comb t - in - (case lookup_mode modes p of - SOME ms => map_filter (fn (Context m, []) => let - val ms = strip_fun_mode m - val (argms, restms) = chop (length args) ms - val m' = fold_rev (curry Fun) restms Bool - in - if forall (fn m => eq_mode (Input, m)) argms andalso eq_mode (m', mode) then - SOME (fold (curry Mode_App) (map Term argms) (Context m), missing_vars vs t) - else NONE - end) ms - | NONE => (if is_all_input mode then [(Context mode, [])] else [])) - end*) - (case try (all_derivations_of ctxt modes vs) t of - SOME derivs => - filter (fn (d, mvars) => eq_mode (mode_of d, m) andalso null (output_terms (t, d))) derivs - | NONE => (if is_all_input m then [(Context m, [])] else [])) - | derivations_of ctxt modes vs t m = - if eq_mode (m, Input) then - [(Term Input, missing_vars vs t)] - else if eq_mode (m, Output) then - (if is_possible_output ctxt vs t then [(Term Output, [])] else []) - else [] -and all_derivations_of ctxt modes vs (Const (@{const_name Pair}, _) $ t1 $ t2) = - let - val derivs1 = all_derivations_of ctxt modes vs t1 - val derivs2 = all_derivations_of ctxt modes vs t2 - in - map_product - (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2)) - derivs1 derivs2 - end - | all_derivations_of ctxt modes vs (t1 $ t2) = - let - val derivs1 = all_derivations_of ctxt modes vs t1 - in - maps (fn (d1, mvars1) => - case mode_of d1 of - Fun (m', _) => map (fn (d2, mvars2) => - (Mode_App (d1, d2), union (op =) mvars1 mvars2)) (derivations_of ctxt modes vs t2 m') - | _ => raise Fail "all_derivations_of: derivation has an unexpected non-functional mode") derivs1 - end - | all_derivations_of _ modes vs (Const (s, T)) = the (lookup_mode modes (Const (s, T))) - | all_derivations_of _ modes vs (Free (x, T)) = the (lookup_mode modes (Free (x, T))) - | all_derivations_of _ modes vs _ = raise Fail "all_derivations_of: unexpected term" - -fun rev_option_ord ord (NONE, NONE) = EQUAL - | rev_option_ord ord (NONE, SOME _) = GREATER - | rev_option_ord ord (SOME _, NONE) = LESS - | rev_option_ord ord (SOME x, SOME y) = ord (x, y) - -fun random_mode_in_deriv modes t deriv = - case try dest_Const (fst (strip_comb t)) of - SOME (s, _) => - (case AList.lookup (op =) modes s of - SOME ms => - (case AList.lookup (op =) (map (fn ((p, m), r) => (m, r)) ms) (head_mode_of deriv) of - SOME r => r - | NONE => false) - | NONE => false) - | NONE => false - -fun number_of_output_positions mode = - let - val args = strip_fun_mode mode - fun contains_output (Fun _) = false - | contains_output Input = false - | contains_output Output = true - | contains_output (Pair (m1, m2)) = contains_output m1 orelse contains_output m2 - in - length (filter contains_output args) - end - -fun lex_ord ord1 ord2 (x, x') = - case ord1 (x, x') of - EQUAL => ord2 (x, x') - | ord => ord - -fun lexl_ord [] (x, x') = EQUAL - | lexl_ord (ord :: ords') (x, x') = - case ord (x, x') of - EQUAL => lexl_ord ords' (x, x') - | ord => ord - -fun deriv_ord' ctxt pol pred modes t1 t2 ((deriv1, mvars1), (deriv2, mvars2)) = - let - (* prefer functional modes if it is a function *) - fun fun_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = - let - fun is_functional t mode = - case try (fst o dest_Const o fst o strip_comb) t of - NONE => false - | SOME c => is_some (alternative_compilation_of ctxt c mode) - in - case (is_functional t1 (head_mode_of deriv1), is_functional t2 (head_mode_of deriv2)) of - (true, true) => EQUAL - | (true, false) => LESS - | (false, true) => GREATER - | (false, false) => EQUAL - end - (* prefer modes without requirement for generating random values *) - fun mvars_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = - int_ord (length mvars1, length mvars2) - (* prefer non-random modes *) - fun random_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = - int_ord (if random_mode_in_deriv modes t1 deriv1 then 1 else 0, - if random_mode_in_deriv modes t2 deriv2 then 1 else 0) - (* prefer modes with more input and less output *) - fun output_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = - int_ord (number_of_output_positions (head_mode_of deriv1), - number_of_output_positions (head_mode_of deriv2)) - (* prefer recursive calls *) - fun is_rec_premise t = - case fst (strip_comb t) of Const (c, T) => c = pred | _ => false - fun recursive_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) = - int_ord (if is_rec_premise t1 then 0 else 1, - if is_rec_premise t2 then 0 else 1) - val ord = lexl_ord [mvars_ord, fun_ord, random_mode_ord, output_mode_ord, recursive_ord] - in - ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) - end - -fun deriv_ord ctxt pol pred modes t = deriv_ord' ctxt pol pred modes t t - -fun premise_ord thy pol pred modes ((prem1, a1), (prem2, a2)) = - rev_option_ord (deriv_ord' thy pol pred modes (dest_indprem prem1) (dest_indprem prem2)) (a1, a2) - -fun print_mode_list modes = - tracing ("modes: " ^ (commas (map (fn (s, ms) => s ^ ": " ^ - commas (map (fn (m, r) => string_of_mode m ^ (if r then " random " else " not ")) ms)) modes))) - -fun select_mode_prem (mode_analysis_options : mode_analysis_options) (ctxt : Proof.context) pred - pol (modes, (pos_modes, neg_modes)) vs ps = - let - fun choose_mode_of_prem (Prem t) = - find_least (deriv_ord ctxt pol pred modes t) (all_derivations_of ctxt pos_modes vs t) - | choose_mode_of_prem (Sidecond t) = SOME (Context Bool, missing_vars vs t) - | choose_mode_of_prem (Negprem t) = find_least (deriv_ord ctxt (not pol) pred modes t) - (filter (fn (d, missing_vars) => is_all_input (head_mode_of d)) - (all_derivations_of ctxt neg_modes vs t)) - | choose_mode_of_prem p = raise Fail ("choose_mode_of_prem: unexpected premise " ^ string_of_prem ctxt p) - in - if #reorder_premises mode_analysis_options then - find_least (premise_ord ctxt pol pred modes) (ps ~~ map choose_mode_of_prem ps) - else - SOME (hd ps, choose_mode_of_prem (hd ps)) - end - -fun check_mode_clause' (mode_analysis_options : mode_analysis_options) ctxt pred param_vs (modes : - (string * ((bool * mode) * bool) list) list) ((pol, mode) : bool * mode) (ts, ps) = - let - val vTs = distinct (op =) (fold Term.add_frees (map dest_indprem ps) (fold Term.add_frees ts [])) - val modes' = modes @ (param_vs ~~ map (fn x => [((true, x), false), ((false, x), false)]) (ho_arg_modes_of mode)) - fun retrieve_modes_of_pol pol = map (fn (s, ms) => - (s, map_filter (fn ((p, m), r) => if p = pol then SOME m else NONE | _ => NONE) ms)) - val (pos_modes', neg_modes') = - if #infer_pos_and_neg_modes mode_analysis_options then - (retrieve_modes_of_pol pol modes', retrieve_modes_of_pol (not pol) modes') - else - let - val modes = map (fn (s, ms) => (s, map (fn ((p, m), r) => m) ms)) modes' - in (modes, modes) end - val (in_ts, out_ts) = split_mode mode ts - val in_vs = union (op =) param_vs (maps (vars_of_destructable_term ctxt) in_ts) - val out_vs = terms_vs out_ts - fun known_vs_after p vs = (case p of - Prem t => union (op =) vs (term_vs t) - | Sidecond t => union (op =) vs (term_vs t) - | Negprem t => union (op =) vs (term_vs t) - | _ => raise Fail "unexpected premise") - fun check_mode_prems acc_ps rnd vs [] = SOME (acc_ps, vs, rnd) - | check_mode_prems acc_ps rnd vs ps = - (case - (select_mode_prem mode_analysis_options ctxt pred pol (modes', (pos_modes', neg_modes')) vs ps) of - SOME (p, SOME (deriv, [])) => check_mode_prems ((p, deriv) :: acc_ps) rnd - (known_vs_after p vs) (filter_out (equal p) ps) - | SOME (p, SOME (deriv, missing_vars)) => - if #use_generators mode_analysis_options andalso pol then - check_mode_prems ((p, deriv) :: (map - (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output)) - (distinct (op =) missing_vars)) - @ acc_ps) true (known_vs_after p vs) (filter_out (equal p) ps) - else NONE - | SOME (p, NONE) => NONE - | NONE => NONE) - in - case check_mode_prems [] false in_vs ps of - NONE => NONE - | SOME (acc_ps, vs, rnd) => - if forall (is_constructable vs) (in_ts @ out_ts) then - SOME (ts, rev acc_ps, rnd) - else - if #use_generators mode_analysis_options andalso pol then - let - val generators = map - (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output)) - (subtract (op =) vs (terms_vs (in_ts @ out_ts))) - in - SOME (ts, rev (generators @ acc_ps), true) - end - else - NONE - end - -datatype result = Success of bool | Error of string - -fun check_modes_pred' mode_analysis_options options thy param_vs clauses modes (p, (ms : ((bool * mode) * bool) list)) = - let - fun split xs = - let - fun split' [] (ys, zs) = (rev ys, rev zs) - | split' ((m, Error z) :: xs) (ys, zs) = split' xs (ys, z :: zs) - | split' (((m : bool * mode), Success rnd) :: xs) (ys, zs) = split' xs ((m, rnd) :: ys, zs) - in - split' xs ([], []) - end - val rs = these (AList.lookup (op =) clauses p) - fun check_mode m = - let - val res = Output.cond_timeit false "work part of check_mode for one mode" (fn _ => - map (check_mode_clause' mode_analysis_options thy p param_vs modes m) rs) - in - Output.cond_timeit false "aux part of check_mode for one mode" (fn _ => - case find_indices is_none res of - [] => Success (exists (fn SOME (_, _, true) => true | _ => false) res) - | is => (print_failed_mode options thy modes p m rs is; Error (error_of p m is))) - end - val _ = if show_mode_inference options then - tracing ("checking " ^ string_of_int (length ms) ^ " modes ...") - else () - val res = Output.cond_timeit false "check_mode" (fn _ => map (fn (m, _) => (m, check_mode m)) ms) - val (ms', errors) = split res - in - ((p, (ms' : ((bool * mode) * bool) list)), errors) - end; - -fun get_modes_pred' mode_analysis_options thy param_vs clauses modes (p, ms) = - let - val rs = these (AList.lookup (op =) clauses p) - in - (p, map (fn (m, rnd) => - (m, map - ((fn (ts, ps, rnd) => (ts, ps)) o the o - check_mode_clause' mode_analysis_options thy p param_vs modes m) rs)) ms) - end; - -fun fixp f (x : (string * ((bool * mode) * bool) list) list) = - let val y = f x - in if x = y then x else fixp f y end; - -fun fixp_with_state f (x : (string * ((bool * mode) * bool) list) list, state) = - let - val (y, state') = f (x, state) - in - if x = y then (y, state') else fixp_with_state f (y, state') - end - -fun string_of_ext_mode ((pol, mode), rnd) = - string_of_mode mode ^ "(" ^ (if pol then "pos" else "neg") ^ ", " - ^ (if rnd then "rnd" else "nornd") ^ ")" - -fun print_extra_modes options modes = - if show_mode_inference options then - tracing ("Modes of inferred predicates: " ^ - cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map string_of_ext_mode ms)) modes)) - else () - -fun infer_modes mode_analysis_options options (lookup_mode, lookup_neg_mode, needs_random) ctxt - preds all_modes param_vs clauses = - let - fun appair f (x1, x2) (y1, y2) = (f x1 y1, f x2 y2) - fun add_needs_random s (false, m) = ((false, m), false) - | add_needs_random s (true, m) = ((true, m), needs_random s m) - fun add_polarity_and_random_bit s b ms = map (fn m => add_needs_random s (b, m)) ms - val prednames = map fst preds - (* extramodes contains all modes of all constants, should we only use the necessary ones - - what is the impact on performance? *) - fun predname_of (Prem t) = - (case try dest_Const (fst (strip_comb t)) of SOME (c, _) => insert (op =) c | NONE => I) - | predname_of (Negprem t) = - (case try dest_Const (fst (strip_comb t)) of SOME (c, _) => insert (op =) c | NONE => I) - | predname_of _ = I - val relevant_prednames = fold (fn (_, clauses') => - fold (fn (_, ps) => fold Term.add_const_names (map dest_indprem ps)) clauses') clauses [] - |> filter_out (fn name => member (op =) prednames name) - val extra_modes = - if #infer_pos_and_neg_modes mode_analysis_options then - let - val pos_extra_modes = - map_filter (fn name => Option.map (pair name) (try lookup_mode name)) - relevant_prednames - val neg_extra_modes = - map_filter (fn name => Option.map (pair name) (try lookup_neg_mode name)) - relevant_prednames - in - map (fn (s, ms) => (s, (add_polarity_and_random_bit s true ms) - @ add_polarity_and_random_bit s false (the (AList.lookup (op =) neg_extra_modes s)))) - pos_extra_modes - end - else - map (fn (s, ms) => (s, (add_polarity_and_random_bit s true ms))) - (map_filter (fn name => Option.map (pair name) (try lookup_mode name)) - relevant_prednames) - val _ = print_extra_modes options extra_modes - val start_modes = - if #infer_pos_and_neg_modes mode_analysis_options then - map (fn (s, ms) => (s, map (fn m => ((true, m), false)) ms @ - (map (fn m => ((false, m), false)) ms))) all_modes - else - map (fn (s, ms) => (s, map (fn m => ((true, m), false)) ms)) all_modes - fun iteration modes = map - (check_modes_pred' mode_analysis_options options ctxt param_vs clauses - (modes @ extra_modes)) modes - val ((modes : (string * ((bool * mode) * bool) list) list), errors) = - Output.cond_timeit false "Fixpount computation of mode analysis" (fn () => - if show_invalid_clauses options then - fixp_with_state (fn (modes, errors) => - let - val (modes', new_errors) = split_list (iteration modes) - in (modes', errors @ flat new_errors) end) (start_modes, []) - else - (fixp (fn modes => map fst (iteration modes)) start_modes, [])) - val moded_clauses = map (get_modes_pred' mode_analysis_options ctxt param_vs clauses - (modes @ extra_modes)) modes - val need_random = fold (fn (s, ms) => if member (op =) (map fst preds) s then - cons (s, (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms)) else I) - modes [] - in - ((moded_clauses, need_random), errors) - end; - (* term construction *) fun mk_v (names, vs) s T = (case AList.lookup (op =) vs s of @@ -2088,11 +1310,6 @@ (HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation)) end; -(** special setup for simpset **) -val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ [@{thm Pair_eq}]) - setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac)) - setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI})) - (* Definition of executable functions and their intro and elim rules *) fun print_arities arities = tracing ("Arities:\n" ^ @@ -2256,483 +1473,6 @@ |> fold create_definition modes end; -(* Proving equivalence of term *) - -fun is_Type (Type _) = true - | is_Type _ = false - -(* returns true if t is an application of an datatype constructor *) -(* which then consequently would be splitted *) -(* else false *) -fun is_constructor thy t = - if (is_Type (fastype_of t)) then - (case Datatype.get_info thy ((fst o dest_Type o fastype_of) t) of - NONE => false - | SOME info => (let - val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info) - val (c, _) = strip_comb t - in (case c of - Const (name, _) => member (op =) constr_consts name - | _ => false) end)) - else false - -(* MAJOR FIXME: prove_params should be simple - - different form of introrule for parameters ? *) - -fun prove_param options ctxt nargs t deriv = - let - val (f, args) = strip_comb (Envir.eta_contract t) - val mode = head_mode_of deriv - val param_derivations = param_derivations_of deriv - val ho_args = ho_args_of mode args - val f_tac = case f of - Const (name, T) => simp_tac (HOL_basic_ss addsimps - [@{thm eval_pred}, predfun_definition_of ctxt name mode, - @{thm split_eta}, @{thm split_beta}, @{thm fst_conv}, - @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1 - | Free _ => - Subgoal.FOCUS_PREMS (fn {context = ctxt, params = params, prems, asms, concl, schematics} => - let - val prems' = maps dest_conjunct_prem (take nargs prems) - in - MetaSimplifier.rewrite_goal_tac - (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1 - end) ctxt 1 - | Abs _ => raise Fail "prove_param: No valid parameter term" - in - REPEAT_DETERM (rtac @{thm ext} 1) - THEN print_tac options "prove_param" - THEN f_tac - THEN print_tac options "after prove_param" - THEN (REPEAT_DETERM (atac 1)) - THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations)) - THEN REPEAT_DETERM (rtac @{thm refl} 1) - end - -fun prove_expr options ctxt nargs (premposition : int) (t, deriv) = - case strip_comb t of - (Const (name, T), args) => - let - val mode = head_mode_of deriv - val introrule = predfun_intro_of ctxt name mode - val param_derivations = param_derivations_of deriv - val ho_args = ho_args_of mode args - in - print_tac options "before intro rule:" - THEN rtac introrule 1 - THEN print_tac options "after intro rule" - (* for the right assumption in first position *) - THEN rotate_tac premposition 1 - THEN atac 1 - THEN print_tac options "parameter goal" - (* work with parameter arguments *) - THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations)) - THEN (REPEAT_DETERM (atac 1)) - end - | (Free _, _) => - print_tac options "proving parameter call.." - THEN Subgoal.FOCUS_PREMS (fn {context = ctxt, params, prems, asms, concl, schematics} => - let - val param_prem = nth prems premposition - val (param, _) = strip_comb (HOLogic.dest_Trueprop (prop_of param_prem)) - val prems' = maps dest_conjunct_prem (take nargs prems) - fun param_rewrite prem = - param = snd (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of prem))) - val SOME rew_eq = find_first param_rewrite prems' - val param_prem' = MetaSimplifier.rewrite_rule - (map (fn th => th RS @{thm eq_reflection}) - [rew_eq RS @{thm sym}, @{thm split_beta}, @{thm fst_conv}, @{thm snd_conv}]) - param_prem - in - rtac param_prem' 1 - end) ctxt 1 - THEN print_tac options "after prove parameter call" - -fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st; - -fun SOLVEDALL tac st = FILTER (fn st' => nprems_of st' = 0) tac st - -fun check_format ctxt st = - let - val concl' = Logic.strip_assums_concl (hd (prems_of st)) - val concl = HOLogic.dest_Trueprop concl' - val expr = fst (strip_comb (fst (PredicateCompFuns.dest_Eval concl))) - fun valid_expr (Const (@{const_name Predicate.bind}, _)) = true - | valid_expr (Const (@{const_name Predicate.single}, _)) = true - | valid_expr _ = false - in - if valid_expr expr then - ((*tracing "expression is valid";*) Seq.single st) - else - ((*tracing "expression is not valid";*) Seq.empty) (* error "check_format: wrong format" *) - end - -fun prove_match options ctxt nargs out_ts = - let - val thy = ProofContext.theory_of ctxt - val eval_if_P = - @{lemma "P ==> Predicate.eval x z ==> Predicate.eval (if P then x else y) z" by simp} - fun get_case_rewrite t = - if (is_constructor thy t) then let - val case_rewrites = (#case_rewrites (Datatype.the_info thy - ((fst o dest_Type o fastype_of) t))) - in fold (union Thm.eq_thm) (case_rewrites :: map get_case_rewrite (snd (strip_comb t))) [] end - else [] - val simprules = insert Thm.eq_thm @{thm "unit.cases"} (insert Thm.eq_thm @{thm "prod.cases"} - (fold (union Thm.eq_thm) (map get_case_rewrite out_ts) [])) - (* replace TRY by determining if it necessary - are there equations when calling compile match? *) - in - (* make this simpset better! *) - asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1 - THEN print_tac options "after prove_match:" - THEN (DETERM (TRY - (rtac eval_if_P 1 - THEN (SUBPROOF (fn {context = ctxt, params, prems, asms, concl, schematics} => - (REPEAT_DETERM (rtac @{thm conjI} 1 - THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1)))) - THEN print_tac options "if condition to be solved:" - THEN asm_simp_tac HOL_basic_ss' 1 - THEN TRY ( - let - val prems' = maps dest_conjunct_prem (take nargs prems) - in - MetaSimplifier.rewrite_goal_tac - (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1 - end - THEN REPEAT_DETERM (rtac @{thm refl} 1)) - THEN print_tac options "after if simp; in SUBPROOF") ctxt 1)))) - THEN print_tac options "after if simplification" - end; - -(* corresponds to compile_fun -- maybe call that also compile_sidecond? *) - -fun prove_sidecond ctxt t = - let - fun preds_of t nameTs = case strip_comb t of - (f as Const (name, T), args) => - if is_registered ctxt name then (name, T) :: nameTs - else fold preds_of args nameTs - | _ => nameTs - val preds = preds_of t [] - val defs = map - (fn (pred, T) => predfun_definition_of ctxt pred - (all_input_of T)) - preds - in - simp_tac (HOL_basic_ss addsimps - (@{thms HOL.simp_thms} @ (@{thm eval_pred} :: defs))) 1 - (* need better control here! *) - end - -fun prove_clause options ctxt nargs mode (_, clauses) (ts, moded_ps) = - let - val (in_ts, clause_out_ts) = split_mode mode ts; - fun prove_prems out_ts [] = - (prove_match options ctxt nargs out_ts) - THEN print_tac options "before simplifying assumptions" - THEN asm_full_simp_tac HOL_basic_ss' 1 - THEN print_tac options "before single intro rule" - THEN Subgoal.FOCUS_PREMS - (fn {context = ctxt, params = params, prems, asms, concl, schematics} => - let - val prems' = maps dest_conjunct_prem (take nargs prems) - in - MetaSimplifier.rewrite_goal_tac - (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1 - end) ctxt 1 - THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1) - | prove_prems out_ts ((p, deriv) :: ps) = - let - val premposition = (find_index (equal p) clauses) + nargs - val mode = head_mode_of deriv - val rest_tac = - rtac @{thm bindI} 1 - THEN (case p of Prem t => - let - val (_, us) = strip_comb t - val (_, out_ts''') = split_mode mode us - val rec_tac = prove_prems out_ts''' ps - in - print_tac options "before clause:" - (*THEN asm_simp_tac HOL_basic_ss 1*) - THEN print_tac options "before prove_expr:" - THEN prove_expr options ctxt nargs premposition (t, deriv) - THEN print_tac options "after prove_expr:" - THEN rec_tac - end - | Negprem t => - let - val (t, args) = strip_comb t - val (_, out_ts''') = split_mode mode args - val rec_tac = prove_prems out_ts''' ps - val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) - val neg_intro_rule = - Option.map (fn name => - the (predfun_neg_intro_of ctxt name mode)) name - val param_derivations = param_derivations_of deriv - val params = ho_args_of mode args - in - print_tac options "before prove_neg_expr:" - THEN full_simp_tac (HOL_basic_ss addsimps - [@{thm split_eta}, @{thm split_beta}, @{thm fst_conv}, - @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1 - THEN (if (is_some name) then - print_tac options "before applying not introduction rule" - THEN Subgoal.FOCUS_PREMS - (fn {context = ctxt, params = params, prems, asms, concl, schematics} => - rtac (the neg_intro_rule) 1 - THEN rtac (nth prems premposition) 1) ctxt 1 - THEN print_tac options "after applying not introduction rule" - THEN (EVERY (map2 (prove_param options ctxt nargs) params param_derivations)) - THEN (REPEAT_DETERM (atac 1)) - else - rtac @{thm not_predI'} 1 - (* test: *) - THEN dtac @{thm sym} 1 - THEN asm_full_simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1) - THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1 - THEN rec_tac - end - | Sidecond t => - rtac @{thm if_predI} 1 - THEN print_tac options "before sidecond:" - THEN prove_sidecond ctxt t - THEN print_tac options "after sidecond:" - THEN prove_prems [] ps) - in (prove_match options ctxt nargs out_ts) - THEN rest_tac - end; - val prems_tac = prove_prems in_ts moded_ps - in - print_tac options "Proving clause..." - THEN rtac @{thm bindI} 1 - THEN rtac @{thm singleI} 1 - THEN prems_tac - end; - -fun select_sup 1 1 = [] - | select_sup _ 1 = [rtac @{thm supI1}] - | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1)); - -fun prove_one_direction options ctxt clauses preds pred mode moded_clauses = - let - val T = the (AList.lookup (op =) preds pred) - val nargs = length (binder_types T) - val pred_case_rule = the_elim_of ctxt pred - in - REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})) - THEN print_tac options "before applying elim rule" - THEN etac (predfun_elim_of ctxt pred mode) 1 - THEN etac pred_case_rule 1 - THEN print_tac options "after applying elim rule" - THEN (EVERY (map - (fn i => EVERY' (select_sup (length moded_clauses) i) i) - (1 upto (length moded_clauses)))) - THEN (EVERY (map2 (prove_clause options ctxt nargs mode) clauses moded_clauses)) - THEN print_tac options "proved one direction" - end; - -(** Proof in the other direction **) - -fun prove_match2 options ctxt out_ts = - let - val thy = ProofContext.theory_of ctxt - fun split_term_tac (Free _) = all_tac - | split_term_tac t = - if (is_constructor thy t) then - let - val info = Datatype.the_info thy ((fst o dest_Type o fastype_of) t) - val num_of_constrs = length (#case_rewrites info) - val (_, ts) = strip_comb t - in - print_tac options ("Term " ^ (Syntax.string_of_term ctxt t) ^ - "splitting with rules \n" ^ Display.string_of_thm ctxt (#split_asm info)) - THEN TRY ((Splitter.split_asm_tac [#split_asm info] 1) - THEN (print_tac options "after splitting with split_asm rules") - (* THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1) - THEN (DETERM (TRY (etac @{thm Pair_inject} 1)))*) - THEN (REPEAT_DETERM_N (num_of_constrs - 1) - (etac @{thm botE} 1 ORELSE etac @{thm botE} 2))) - THEN (assert_tac (Max_number_of_subgoals 2)) - THEN (EVERY (map split_term_tac ts)) - end - else all_tac - in - split_term_tac (HOLogic.mk_tuple out_ts) - THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1) - THEN (etac @{thm botE} 2)))) - end - -(* VERY LARGE SIMILIRATIY to function prove_param --- join both functions -*) -(* TODO: remove function *) - -fun prove_param2 options ctxt t deriv = - let - val (f, args) = strip_comb (Envir.eta_contract t) - val mode = head_mode_of deriv - val param_derivations = param_derivations_of deriv - val ho_args = ho_args_of mode args - val f_tac = case f of - Const (name, T) => full_simp_tac (HOL_basic_ss addsimps - (@{thm eval_pred}::(predfun_definition_of ctxt name mode) - :: @{thm "Product_Type.split_conv"}::[])) 1 - | Free _ => all_tac - | _ => error "prove_param2: illegal parameter term" - in - print_tac options "before simplification in prove_args:" - THEN f_tac - THEN print_tac options "after simplification in prove_args" - THEN EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations) - end - -fun prove_expr2 options ctxt (t, deriv) = - (case strip_comb t of - (Const (name, T), args) => - let - val mode = head_mode_of deriv - val param_derivations = param_derivations_of deriv - val ho_args = ho_args_of mode args - in - etac @{thm bindE} 1 - THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))) - THEN print_tac options "prove_expr2-before" - THEN etac (predfun_elim_of ctxt name mode) 1 - THEN print_tac options "prove_expr2" - THEN (EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations)) - THEN print_tac options "finished prove_expr2" - end - | _ => etac @{thm bindE} 1) - -fun prove_sidecond2 options ctxt t = let - fun preds_of t nameTs = case strip_comb t of - (f as Const (name, T), args) => - if is_registered ctxt name then (name, T) :: nameTs - else fold preds_of args nameTs - | _ => nameTs - val preds = preds_of t [] - val defs = map - (fn (pred, T) => predfun_definition_of ctxt pred - (all_input_of T)) - preds - in - (* only simplify the one assumption *) - full_simp_tac (HOL_basic_ss' addsimps @{thm eval_pred} :: defs) 1 - (* need better control here! *) - THEN print_tac options "after sidecond2 simplification" - end - -fun prove_clause2 options ctxt pred mode (ts, ps) i = - let - val pred_intro_rule = nth (intros_of ctxt pred) (i - 1) - val (in_ts, clause_out_ts) = split_mode mode ts; - val split_ss = HOL_basic_ss' addsimps [@{thm split_eta}, @{thm split_beta}, - @{thm fst_conv}, @{thm snd_conv}, @{thm pair_collapse}] - fun prove_prems2 out_ts [] = - print_tac options "before prove_match2 - last call:" - THEN prove_match2 options ctxt out_ts - THEN print_tac options "after prove_match2 - last call:" - THEN (etac @{thm singleE} 1) - THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1)) - THEN (asm_full_simp_tac HOL_basic_ss' 1) - THEN TRY ( - (REPEAT_DETERM (etac @{thm Pair_inject} 1)) - THEN (asm_full_simp_tac HOL_basic_ss' 1) - - THEN SOLVED (print_tac options "state before applying intro rule:" - THEN (rtac pred_intro_rule - (* How to handle equality correctly? *) - THEN_ALL_NEW (K (print_tac options "state before assumption matching") - THEN' (atac ORELSE' ((CHANGED o asm_full_simp_tac split_ss) THEN' (TRY o atac))) - THEN' (K (print_tac options "state after pre-simplification:")) - THEN' (K (print_tac options "state after assumption matching:")))) 1)) - | prove_prems2 out_ts ((p, deriv) :: ps) = - let - val mode = head_mode_of deriv - val rest_tac = (case p of - Prem t => - let - val (_, us) = strip_comb t - val (_, out_ts''') = split_mode mode us - val rec_tac = prove_prems2 out_ts''' ps - in - (prove_expr2 options ctxt (t, deriv)) THEN rec_tac - end - | Negprem t => - let - val (_, args) = strip_comb t - val (_, out_ts''') = split_mode mode args - val rec_tac = prove_prems2 out_ts''' ps - val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) - val param_derivations = param_derivations_of deriv - val ho_args = ho_args_of mode args - in - print_tac options "before neg prem 2" - THEN etac @{thm bindE} 1 - THEN (if is_some name then - full_simp_tac (HOL_basic_ss addsimps - [predfun_definition_of ctxt (the name) mode]) 1 - THEN etac @{thm not_predE} 1 - THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1 - THEN (EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations)) - else - etac @{thm not_predE'} 1) - THEN rec_tac - end - | Sidecond t => - etac @{thm bindE} 1 - THEN etac @{thm if_predE} 1 - THEN prove_sidecond2 options ctxt t - THEN prove_prems2 [] ps) - in print_tac options "before prove_match2:" - THEN prove_match2 options ctxt out_ts - THEN print_tac options "after prove_match2:" - THEN rest_tac - end; - val prems_tac = prove_prems2 in_ts ps - in - print_tac options "starting prove_clause2" - THEN etac @{thm bindE} 1 - THEN (etac @{thm singleE'} 1) - THEN (TRY (etac @{thm Pair_inject} 1)) - THEN print_tac options "after singleE':" - THEN prems_tac - end; - -fun prove_other_direction options ctxt pred mode moded_clauses = - let - fun prove_clause clause i = - (if i < length moded_clauses then etac @{thm supE} 1 else all_tac) - THEN (prove_clause2 options ctxt pred mode clause i) - in - (DETERM (TRY (rtac @{thm unit.induct} 1))) - THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all}))) - THEN (rtac (predfun_intro_of ctxt pred mode) 1) - THEN (REPEAT_DETERM (rtac @{thm refl} 2)) - THEN (if null moded_clauses then - etac @{thm botE} 1 - else EVERY (map2 prove_clause moded_clauses (1 upto (length moded_clauses)))) - end; - -(** proof procedure **) - -fun prove_pred options thy clauses preds pred (pol, mode) (moded_clauses, compiled_term) = - let - val ctxt = ProofContext.init_global thy - val clauses = case AList.lookup (op =) clauses pred of SOME rs => rs | NONE => [] - in - Goal.prove ctxt (Term.add_free_names compiled_term []) [] compiled_term - (if not (skip_proof options) then - (fn _ => - rtac @{thm pred_iffI} 1 - THEN print_tac options "after pred_iffI" - THEN prove_one_direction options ctxt clauses preds pred mode moded_clauses - THEN print_tac options "proved one direction" - THEN prove_other_direction options ctxt pred mode moded_clauses - THEN print_tac options "proved other direction") - else (fn _ => Skip_Proof.cheat_tac thy)) - end; (* composition of mode inference, definition, compilation and proof *) diff -r b6acda4d1c29 -r ea46574ca815 src/HOL/Tools/Predicate_Compile/predicate_compile_proof.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_proof.ML Thu Oct 21 19:13:09 2010 +0200 @@ -0,0 +1,525 @@ +(* Title: HOL/Tools/Predicate_Compile/predicate_compile_proof.ML + Author: Lukas Bulwahn, TU Muenchen + +Proof procedure for the compiler from predicates specified by intro/elim rules to equations. +*) + +signature PREDICATE_COMPILE_PROOF = +sig + type indprem = Predicate_Compile_Aux.indprem; + type mode = Predicate_Compile_Aux.mode + val prove_pred : Predicate_Compile_Aux.options -> theory + -> (string * (term list * indprem list) list) list + -> (string * typ) list -> string -> bool * mode + -> (term list * (indprem * Mode_Inference.mode_derivation) list) list * term + -> Thm.thm +end; + +structure Predicate_Compile_Proof : PREDICATE_COMPILE_PROOF = +struct + +open Predicate_Compile_Aux; +open Core_Data; +open Mode_Inference; + +(* debug stuff *) + +fun print_tac options s = + if show_proof_trace options then Tactical.print_tac s else Seq.single; + +(** auxiliary **) + +fun assert b = if not b then raise Fail "Assertion failed" else warning "Assertion holds" + +datatype assertion = Max_number_of_subgoals of int +fun assert_tac (Max_number_of_subgoals i) st = + if (nprems_of st <= i) then Seq.single st + else raise Fail ("assert_tac: Numbers of subgoals mismatch at goal state :" + ^ "\n" ^ Pretty.string_of (Pretty.chunks + (Goal_Display.pretty_goals_without_context st))); + + +(** special setup for simpset **) +val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ [@{thm Pair_eq}]) + setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac)) + setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI})) + +(* auxillary functions *) + +fun is_Type (Type _) = true + | is_Type _ = false + +(* returns true if t is an application of an datatype constructor *) +(* which then consequently would be splitted *) +(* else false *) +fun is_constructor thy t = + if (is_Type (fastype_of t)) then + (case Datatype.get_info thy ((fst o dest_Type o fastype_of) t) of + NONE => false + | SOME info => (let + val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info) + val (c, _) = strip_comb t + in (case c of + Const (name, _) => member (op =) constr_consts name + | _ => false) end)) + else false + +(* MAJOR FIXME: prove_params should be simple + - different form of introrule for parameters ? *) + +fun prove_param options ctxt nargs t deriv = + let + val (f, args) = strip_comb (Envir.eta_contract t) + val mode = head_mode_of deriv + val param_derivations = param_derivations_of deriv + val ho_args = ho_args_of mode args + val f_tac = case f of + Const (name, T) => simp_tac (HOL_basic_ss addsimps + [@{thm eval_pred}, predfun_definition_of ctxt name mode, + @{thm split_eta}, @{thm split_beta}, @{thm fst_conv}, + @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1 + | Free _ => + Subgoal.FOCUS_PREMS (fn {context = ctxt, params = params, prems, asms, concl, schematics} => + let + val prems' = maps dest_conjunct_prem (take nargs prems) + in + MetaSimplifier.rewrite_goal_tac + (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1 + end) ctxt 1 + | Abs _ => raise Fail "prove_param: No valid parameter term" + in + REPEAT_DETERM (rtac @{thm ext} 1) + THEN print_tac options "prove_param" + THEN f_tac + THEN print_tac options "after prove_param" + THEN (REPEAT_DETERM (atac 1)) + THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations)) + THEN REPEAT_DETERM (rtac @{thm refl} 1) + end + +fun prove_expr options ctxt nargs (premposition : int) (t, deriv) = + case strip_comb t of + (Const (name, T), args) => + let + val mode = head_mode_of deriv + val introrule = predfun_intro_of ctxt name mode + val param_derivations = param_derivations_of deriv + val ho_args = ho_args_of mode args + in + print_tac options "before intro rule:" + THEN rtac introrule 1 + THEN print_tac options "after intro rule" + (* for the right assumption in first position *) + THEN rotate_tac premposition 1 + THEN atac 1 + THEN print_tac options "parameter goal" + (* work with parameter arguments *) + THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations)) + THEN (REPEAT_DETERM (atac 1)) + end + | (Free _, _) => + print_tac options "proving parameter call.." + THEN Subgoal.FOCUS_PREMS (fn {context = ctxt, params, prems, asms, concl, schematics} => + let + val param_prem = nth prems premposition + val (param, _) = strip_comb (HOLogic.dest_Trueprop (prop_of param_prem)) + val prems' = maps dest_conjunct_prem (take nargs prems) + fun param_rewrite prem = + param = snd (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of prem))) + val SOME rew_eq = find_first param_rewrite prems' + val param_prem' = MetaSimplifier.rewrite_rule + (map (fn th => th RS @{thm eq_reflection}) + [rew_eq RS @{thm sym}, @{thm split_beta}, @{thm fst_conv}, @{thm snd_conv}]) + param_prem + in + rtac param_prem' 1 + end) ctxt 1 + THEN print_tac options "after prove parameter call" + +fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st; + +fun SOLVEDALL tac st = FILTER (fn st' => nprems_of st' = 0) tac st + +fun check_format ctxt st = + let + val concl' = Logic.strip_assums_concl (hd (prems_of st)) + val concl = HOLogic.dest_Trueprop concl' + val expr = fst (strip_comb (fst (PredicateCompFuns.dest_Eval concl))) + fun valid_expr (Const (@{const_name Predicate.bind}, _)) = true + | valid_expr (Const (@{const_name Predicate.single}, _)) = true + | valid_expr _ = false + in + if valid_expr expr then + ((*tracing "expression is valid";*) Seq.single st) + else + ((*tracing "expression is not valid";*) Seq.empty) (* error "check_format: wrong format" *) + end + +fun prove_match options ctxt nargs out_ts = + let + val thy = ProofContext.theory_of ctxt + val eval_if_P = + @{lemma "P ==> Predicate.eval x z ==> Predicate.eval (if P then x else y) z" by simp} + fun get_case_rewrite t = + if (is_constructor thy t) then let + val case_rewrites = (#case_rewrites (Datatype.the_info thy + ((fst o dest_Type o fastype_of) t))) + in fold (union Thm.eq_thm) (case_rewrites :: map get_case_rewrite (snd (strip_comb t))) [] end + else [] + val simprules = insert Thm.eq_thm @{thm "unit.cases"} (insert Thm.eq_thm @{thm "prod.cases"} + (fold (union Thm.eq_thm) (map get_case_rewrite out_ts) [])) + (* replace TRY by determining if it necessary - are there equations when calling compile match? *) + in + (* make this simpset better! *) + asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1 + THEN print_tac options "after prove_match:" + THEN (DETERM (TRY + (rtac eval_if_P 1 + THEN (SUBPROOF (fn {context = ctxt, params, prems, asms, concl, schematics} => + (REPEAT_DETERM (rtac @{thm conjI} 1 + THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1)))) + THEN print_tac options "if condition to be solved:" + THEN asm_simp_tac HOL_basic_ss' 1 + THEN TRY ( + let + val prems' = maps dest_conjunct_prem (take nargs prems) + in + MetaSimplifier.rewrite_goal_tac + (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1 + end + THEN REPEAT_DETERM (rtac @{thm refl} 1)) + THEN print_tac options "after if simp; in SUBPROOF") ctxt 1)))) + THEN print_tac options "after if simplification" + end; + +(* corresponds to compile_fun -- maybe call that also compile_sidecond? *) + +fun prove_sidecond ctxt t = + let + fun preds_of t nameTs = case strip_comb t of + (f as Const (name, T), args) => + if is_registered ctxt name then (name, T) :: nameTs + else fold preds_of args nameTs + | _ => nameTs + val preds = preds_of t [] + val defs = map + (fn (pred, T) => predfun_definition_of ctxt pred + (all_input_of T)) + preds + in + simp_tac (HOL_basic_ss addsimps + (@{thms HOL.simp_thms} @ (@{thm eval_pred} :: defs))) 1 + (* need better control here! *) + end + +fun prove_clause options ctxt nargs mode (_, clauses) (ts, moded_ps) = + let + val (in_ts, clause_out_ts) = split_mode mode ts; + fun prove_prems out_ts [] = + (prove_match options ctxt nargs out_ts) + THEN print_tac options "before simplifying assumptions" + THEN asm_full_simp_tac HOL_basic_ss' 1 + THEN print_tac options "before single intro rule" + THEN Subgoal.FOCUS_PREMS + (fn {context = ctxt, params = params, prems, asms, concl, schematics} => + let + val prems' = maps dest_conjunct_prem (take nargs prems) + in + MetaSimplifier.rewrite_goal_tac + (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1 + end) ctxt 1 + THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1) + | prove_prems out_ts ((p, deriv) :: ps) = + let + val premposition = (find_index (equal p) clauses) + nargs + val mode = head_mode_of deriv + val rest_tac = + rtac @{thm bindI} 1 + THEN (case p of Prem t => + let + val (_, us) = strip_comb t + val (_, out_ts''') = split_mode mode us + val rec_tac = prove_prems out_ts''' ps + in + print_tac options "before clause:" + (*THEN asm_simp_tac HOL_basic_ss 1*) + THEN print_tac options "before prove_expr:" + THEN prove_expr options ctxt nargs premposition (t, deriv) + THEN print_tac options "after prove_expr:" + THEN rec_tac + end + | Negprem t => + let + val (t, args) = strip_comb t + val (_, out_ts''') = split_mode mode args + val rec_tac = prove_prems out_ts''' ps + val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) + val neg_intro_rule = + Option.map (fn name => + the (predfun_neg_intro_of ctxt name mode)) name + val param_derivations = param_derivations_of deriv + val params = ho_args_of mode args + in + print_tac options "before prove_neg_expr:" + THEN full_simp_tac (HOL_basic_ss addsimps + [@{thm split_eta}, @{thm split_beta}, @{thm fst_conv}, + @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1 + THEN (if (is_some name) then + print_tac options "before applying not introduction rule" + THEN Subgoal.FOCUS_PREMS + (fn {context = ctxt, params = params, prems, asms, concl, schematics} => + rtac (the neg_intro_rule) 1 + THEN rtac (nth prems premposition) 1) ctxt 1 + THEN print_tac options "after applying not introduction rule" + THEN (EVERY (map2 (prove_param options ctxt nargs) params param_derivations)) + THEN (REPEAT_DETERM (atac 1)) + else + rtac @{thm not_predI'} 1 + (* test: *) + THEN dtac @{thm sym} 1 + THEN asm_full_simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1) + THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1 + THEN rec_tac + end + | Sidecond t => + rtac @{thm if_predI} 1 + THEN print_tac options "before sidecond:" + THEN prove_sidecond ctxt t + THEN print_tac options "after sidecond:" + THEN prove_prems [] ps) + in (prove_match options ctxt nargs out_ts) + THEN rest_tac + end; + val prems_tac = prove_prems in_ts moded_ps + in + print_tac options "Proving clause..." + THEN rtac @{thm bindI} 1 + THEN rtac @{thm singleI} 1 + THEN prems_tac + end; + +fun select_sup 1 1 = [] + | select_sup _ 1 = [rtac @{thm supI1}] + | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1)); + +fun prove_one_direction options ctxt clauses preds pred mode moded_clauses = + let + val T = the (AList.lookup (op =) preds pred) + val nargs = length (binder_types T) + val pred_case_rule = the_elim_of ctxt pred + in + REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})) + THEN print_tac options "before applying elim rule" + THEN etac (predfun_elim_of ctxt pred mode) 1 + THEN etac pred_case_rule 1 + THEN print_tac options "after applying elim rule" + THEN (EVERY (map + (fn i => EVERY' (select_sup (length moded_clauses) i) i) + (1 upto (length moded_clauses)))) + THEN (EVERY (map2 (prove_clause options ctxt nargs mode) clauses moded_clauses)) + THEN print_tac options "proved one direction" + end; + +(** Proof in the other direction **) + +fun prove_match2 options ctxt out_ts = + let + val thy = ProofContext.theory_of ctxt + fun split_term_tac (Free _) = all_tac + | split_term_tac t = + if (is_constructor thy t) then + let + val info = Datatype.the_info thy ((fst o dest_Type o fastype_of) t) + val num_of_constrs = length (#case_rewrites info) + val (_, ts) = strip_comb t + in + print_tac options ("Term " ^ (Syntax.string_of_term ctxt t) ^ + "splitting with rules \n" ^ Display.string_of_thm ctxt (#split_asm info)) + THEN TRY ((Splitter.split_asm_tac [#split_asm info] 1) + THEN (print_tac options "after splitting with split_asm rules") + (* THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1) + THEN (DETERM (TRY (etac @{thm Pair_inject} 1)))*) + THEN (REPEAT_DETERM_N (num_of_constrs - 1) + (etac @{thm botE} 1 ORELSE etac @{thm botE} 2))) + THEN (assert_tac (Max_number_of_subgoals 2)) + THEN (EVERY (map split_term_tac ts)) + end + else all_tac + in + split_term_tac (HOLogic.mk_tuple out_ts) + THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1) + THEN (etac @{thm botE} 2)))) + end + +(* VERY LARGE SIMILIRATIY to function prove_param +-- join both functions +*) +(* TODO: remove function *) + +fun prove_param2 options ctxt t deriv = + let + val (f, args) = strip_comb (Envir.eta_contract t) + val mode = head_mode_of deriv + val param_derivations = param_derivations_of deriv + val ho_args = ho_args_of mode args + val f_tac = case f of + Const (name, T) => full_simp_tac (HOL_basic_ss addsimps + (@{thm eval_pred}::(predfun_definition_of ctxt name mode) + :: @{thm "Product_Type.split_conv"}::[])) 1 + | Free _ => all_tac + | _ => error "prove_param2: illegal parameter term" + in + print_tac options "before simplification in prove_args:" + THEN f_tac + THEN print_tac options "after simplification in prove_args" + THEN EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations) + end + +fun prove_expr2 options ctxt (t, deriv) = + (case strip_comb t of + (Const (name, T), args) => + let + val mode = head_mode_of deriv + val param_derivations = param_derivations_of deriv + val ho_args = ho_args_of mode args + in + etac @{thm bindE} 1 + THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))) + THEN print_tac options "prove_expr2-before" + THEN etac (predfun_elim_of ctxt name mode) 1 + THEN print_tac options "prove_expr2" + THEN (EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations)) + THEN print_tac options "finished prove_expr2" + end + | _ => etac @{thm bindE} 1) + +fun prove_sidecond2 options ctxt t = let + fun preds_of t nameTs = case strip_comb t of + (f as Const (name, T), args) => + if is_registered ctxt name then (name, T) :: nameTs + else fold preds_of args nameTs + | _ => nameTs + val preds = preds_of t [] + val defs = map + (fn (pred, T) => predfun_definition_of ctxt pred + (all_input_of T)) + preds + in + (* only simplify the one assumption *) + full_simp_tac (HOL_basic_ss' addsimps @{thm eval_pred} :: defs) 1 + (* need better control here! *) + THEN print_tac options "after sidecond2 simplification" + end + +fun prove_clause2 options ctxt pred mode (ts, ps) i = + let + val pred_intro_rule = nth (intros_of ctxt pred) (i - 1) + val (in_ts, clause_out_ts) = split_mode mode ts; + val split_ss = HOL_basic_ss' addsimps [@{thm split_eta}, @{thm split_beta}, + @{thm fst_conv}, @{thm snd_conv}, @{thm pair_collapse}] + fun prove_prems2 out_ts [] = + print_tac options "before prove_match2 - last call:" + THEN prove_match2 options ctxt out_ts + THEN print_tac options "after prove_match2 - last call:" + THEN (etac @{thm singleE} 1) + THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1)) + THEN (asm_full_simp_tac HOL_basic_ss' 1) + THEN TRY ( + (REPEAT_DETERM (etac @{thm Pair_inject} 1)) + THEN (asm_full_simp_tac HOL_basic_ss' 1) + + THEN SOLVED (print_tac options "state before applying intro rule:" + THEN (rtac pred_intro_rule + (* How to handle equality correctly? *) + THEN_ALL_NEW (K (print_tac options "state before assumption matching") + THEN' (atac ORELSE' ((CHANGED o asm_full_simp_tac split_ss) THEN' (TRY o atac))) + THEN' (K (print_tac options "state after pre-simplification:")) + THEN' (K (print_tac options "state after assumption matching:")))) 1)) + | prove_prems2 out_ts ((p, deriv) :: ps) = + let + val mode = head_mode_of deriv + val rest_tac = (case p of + Prem t => + let + val (_, us) = strip_comb t + val (_, out_ts''') = split_mode mode us + val rec_tac = prove_prems2 out_ts''' ps + in + (prove_expr2 options ctxt (t, deriv)) THEN rec_tac + end + | Negprem t => + let + val (_, args) = strip_comb t + val (_, out_ts''') = split_mode mode args + val rec_tac = prove_prems2 out_ts''' ps + val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) + val param_derivations = param_derivations_of deriv + val ho_args = ho_args_of mode args + in + print_tac options "before neg prem 2" + THEN etac @{thm bindE} 1 + THEN (if is_some name then + full_simp_tac (HOL_basic_ss addsimps + [predfun_definition_of ctxt (the name) mode]) 1 + THEN etac @{thm not_predE} 1 + THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1 + THEN (EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations)) + else + etac @{thm not_predE'} 1) + THEN rec_tac + end + | Sidecond t => + etac @{thm bindE} 1 + THEN etac @{thm if_predE} 1 + THEN prove_sidecond2 options ctxt t + THEN prove_prems2 [] ps) + in print_tac options "before prove_match2:" + THEN prove_match2 options ctxt out_ts + THEN print_tac options "after prove_match2:" + THEN rest_tac + end; + val prems_tac = prove_prems2 in_ts ps + in + print_tac options "starting prove_clause2" + THEN etac @{thm bindE} 1 + THEN (etac @{thm singleE'} 1) + THEN (TRY (etac @{thm Pair_inject} 1)) + THEN print_tac options "after singleE':" + THEN prems_tac + end; + +fun prove_other_direction options ctxt pred mode moded_clauses = + let + fun prove_clause clause i = + (if i < length moded_clauses then etac @{thm supE} 1 else all_tac) + THEN (prove_clause2 options ctxt pred mode clause i) + in + (DETERM (TRY (rtac @{thm unit.induct} 1))) + THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all}))) + THEN (rtac (predfun_intro_of ctxt pred mode) 1) + THEN (REPEAT_DETERM (rtac @{thm refl} 2)) + THEN (if null moded_clauses then + etac @{thm botE} 1 + else EVERY (map2 prove_clause moded_clauses (1 upto (length moded_clauses)))) + end; + +(** proof procedure **) + +fun prove_pred options thy clauses preds pred (pol, mode) (moded_clauses, compiled_term) = + let + val ctxt = ProofContext.init_global thy + val clauses = case AList.lookup (op =) clauses pred of SOME rs => rs | NONE => [] + in + Goal.prove ctxt (Term.add_free_names compiled_term []) [] compiled_term + (if not (skip_proof options) then + (fn _ => + rtac @{thm pred_iffI} 1 + THEN print_tac options "after pred_iffI" + THEN prove_one_direction options ctxt clauses preds pred mode moded_clauses + THEN print_tac options "proved one direction" + THEN prove_other_direction options ctxt pred mode moded_clauses + THEN print_tac options "proved other direction") + else (fn _ => Skip_Proof.cheat_tac thy)) + end; + +end; \ No newline at end of file