# HG changeset patch # User boehmes # Date 1292398764 -3600 # Node ID e0bd443c0fdd8f191bb5ca7d61235d511f9b7a7f # Parent 4a9eec045f2a26d7be203aff1b86bd3d326bb431 re-ordered SMT normalization code (eta-normalization, lambda abstractions and partial functions will be dealt with on the term level); added simple trigger inference mechanism; added syntactic checks for triggers and quantifier weights; factored out the normalization of special quantifiers (used to be in the eta-normalization part); normalization now unfolds abs/min/max (not SMT-LIB-specific); rules for pairs and function update are not anymore added automatically to the problem; more aggressive rewriting of natural number operations into integer operations (minimizes the number of remaining nat-int coercions); normalizations are now managed in a class-based manner (similar to built-in symbols) diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/SMT.thy --- a/src/HOL/SMT.thy Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/SMT.thy Wed Dec 15 08:39:24 2010 +0100 @@ -130,18 +130,20 @@ definition z3mod :: "int \ int \ int" where "z3mod k l = (if 0 \ l then k mod l else k mod (-l))" -lemma div_by_z3div: "k div l = ( - if k = 0 \ l = 0 then 0 - else if (0 < k \ 0 < l) \ (k < 0 \ 0 < l) then z3div k l - else z3div (-k) (-l))" - by (auto simp add: z3div_def) +lemma div_by_z3div: + "\k l. k div l = ( + if k = 0 \ l = 0 then 0 + else if (0 < k \ 0 < l) \ (k < 0 \ 0 < l) then z3div k l + else z3div (-k) (-l))" + by (auto simp add: z3div_def trigger_def) -lemma mod_by_z3mod: "k mod l = ( - if l = 0 then k - else if k = 0 then 0 - else if (0 < k \ 0 < l) \ (k < 0 \ 0 < l) then z3mod k l - else - z3mod (-k) (-l))" - by (auto simp add: z3mod_def) +lemma mod_by_z3mod: + "\k l. k mod l = ( + if l = 0 then k + else if k = 0 then 0 + else if (0 < k \ 0 < l) \ (k < 0 \ 0 < l) then z3mod k l + else - z3mod (-k) (-l))" + by (auto simp add: z3mod_def trigger_def) diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/Tools/SMT/smt_builtin.ML --- a/src/HOL/Tools/SMT/smt_builtin.ML Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/Tools/SMT/smt_builtin.ML Wed Dec 15 08:39:24 2010 +0100 @@ -37,6 +37,7 @@ val is_builtin_fun: Proof.context -> string * typ -> term list -> bool val is_builtin_pred: Proof.context -> string * typ -> term list -> bool val is_builtin_conn: Proof.context -> string * typ -> term list -> bool + val is_builtin_fun_ext: Proof.context -> string * typ -> term list -> bool val is_builtin_ext: Proof.context -> string * typ -> term list -> bool end @@ -78,8 +79,6 @@ type ('a, 'b) btab = ('a, 'b) ttab Symtab.table -fun empty_btab () = Symtab.empty - fun insert_btab cs n T f = Symtab.map_default (n, []) (insert_ttab cs T f) @@ -147,26 +146,10 @@ type 'a bfun = Proof.context -> typ -> term list -> 'a -fun true3 _ _ _ = true - -fun raw_add_builtin_fun_ext thy cs n = - insert_btab cs n (Sign.the_const_type thy n) (Ext true3) - -val basic_builtin_fun_names = [ - @{const_name SMT.pat}, @{const_name SMT.nopat}, - @{const_name SMT.trigger}, @{const_name SMT.weight}] - -type builtin_funcs = (bool bfun, (string * term list) option bfun) btab - -fun basic_builtin_funcs () : builtin_funcs = - empty_btab () - |> fold (raw_add_builtin_fun_ext @{theory} U.basicC) basic_builtin_fun_names - (* FIXME: SMT_Normalize should check that they are properly used *) - structure Builtin_Funcs = Generic_Data ( - type T = builtin_funcs - val empty = basic_builtin_funcs () + type T = (bool bfun, (string * term list) option bfun) btab + val empty = Symtab.empty val extend = I val merge = merge_btab ) @@ -180,7 +163,8 @@ fun add_builtin_fun_ext ((n, T), f) = Builtin_Funcs.map (insert_btab U.basicC n T (Ext f)) -fun add_builtin_fun_ext' c = add_builtin_fun_ext (c, true3) +fun add_builtin_fun_ext' c = + add_builtin_fun_ext (c, fn _ => fn _ => fn _ => true) fun add_builtin_fun_ext'' n context = let val thy = Context.theory_of context diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/Tools/SMT/smt_normalize.ML --- a/src/HOL/Tools/SMT/smt_normalize.ML Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/Tools/SMT/smt_normalize.ML Wed Dec 15 08:39:24 2010 +0100 @@ -1,28 +1,17 @@ (* Title: HOL/Tools/SMT/smt_normalize.ML Author: Sascha Boehme, TU Muenchen -Normalization steps on theorems required by SMT solvers: - * simplify trivial distincts (those with less than three elements), - * rewrite bool case expressions as if expressions, - * normalize numerals (e.g. replace negative numerals by negated positive - numerals), - * embed natural numbers into integers, - * add extra rules specifying types and constants which occur frequently, - * fully translate into object logic, add universal closure, - * monomorphize (create instances of schematic rules), - * lift lambda terms, - * make applications explicit for functions with varying number of arguments. - * add (hypothetical definitions for) missing datatype selectors, +Normalization steps on theorems required by SMT solvers. *) signature SMT_NORMALIZE = sig - type extra_norm = bool -> (int * thm) list -> Proof.context -> - (int * thm) list * Proof.context - val normalize: extra_norm -> bool -> (int * thm) list -> Proof.context -> - (int * thm) list * Proof.context val atomize_conv: Proof.context -> conv - val eta_expand_conv: (Proof.context -> conv) -> Proof.context -> conv + type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list + val add_extra_norm: SMT_Utils.class * extra_norm -> Context.generic -> + Context.generic + val normalize: Proof.context -> (int * (int option * thm)) list -> + (int * thm) list val setup: theory -> theory end @@ -32,12 +21,10 @@ structure U = SMT_Utils structure B = SMT_Builtin -infix 2 ?? -fun (test ?? f) x = if test x then f x else x - +(* general theorem normalizations *) -(* instantiate elimination rules *) +(** instantiate elimination rules **) local val (cpfalse, cfalse) = `U.mk_cprop (Thm.cterm_of @{theory} @{const False}) @@ -56,281 +43,18 @@ end - -(* simplification of trivial distincts (distinct should have at least - three elements in the argument list) *) - -local - fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) = - (case try HOLogic.dest_list t of - SOME [] => true - | SOME [_] => true - | _ => false) - | is_trivial_distinct _ = false - - val thms = map mk_meta_eq @{lemma - "distinct [] = True" - "distinct [x] = True" - "distinct [x, y] = (x ~= y)" - by simp_all} - fun distinct_conv _ = - U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms) -in -fun trivial_distinct ctxt = - map (apsnd ((Term.exists_subterm is_trivial_distinct o Thm.prop_of) ?? - Conv.fconv_rule (Conv.top_conv distinct_conv ctxt))) -end - - - -(* rewrite bool case expressions as if expressions *) - -local - val is_bool_case = (fn - Const (@{const_name "bool.bool_case"}, _) $ _ $ _ $ _ => true - | _ => false) +(** normalize definitions **) - val thm = mk_meta_eq @{lemma - "(case P of True => x | False => y) = (if P then x else y)" by simp} - val unfold_conv = U.if_true_conv is_bool_case (Conv.rewr_conv thm) -in -fun rewrite_bool_cases ctxt = - map (apsnd ((Term.exists_subterm is_bool_case o Thm.prop_of) ?? - Conv.fconv_rule (Conv.top_conv (K unfold_conv) ctxt))) - -val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"} - -end - - - -(* normalization of numerals: rewriting of negative integer numerals into - positive numerals, Numeral0 into 0, Numeral1 into 1 *) - -local - fun is_number_sort ctxt T = - Sign.of_sort (ProofContext.theory_of ctxt) (T, @{sort number_ring}) - - fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) = - (case try HOLogic.dest_number t of - SOME (T, i) => is_number_sort ctxt T andalso i < 2 - | NONE => false) - | is_strange_number _ _ = false - - val pos_numeral_ss = HOL_ss - addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}] - addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}] - addsimps @{thms Int.pred_bin_simps} - addsimps @{thms Int.normalize_bin_simps} - addsimps @{lemma - "Int.Min = - Int.Bit1 Int.Pls" - "Int.Bit0 (- Int.Pls) = - Int.Pls" - "Int.Bit0 (- k) = - Int.Bit0 k" - "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)" - by simp_all (simp add: pred_def)} - - fun pos_conv ctxt = U.if_conv (is_strange_number ctxt) - (Simplifier.rewrite (Simplifier.context ctxt pos_numeral_ss)) - Conv.no_conv -in -fun normalize_numerals ctxt = - map (apsnd ((Term.exists_subterm (is_strange_number ctxt) o Thm.prop_of) ?? - Conv.fconv_rule (Conv.top_sweep_conv pos_conv ctxt))) -end - +fun norm_def thm = + (case Thm.prop_of thm of + @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) => + norm_def (thm RS @{thm fun_cong}) + | Const (@{const_name "=="}, _) $ _ $ Abs _ => + norm_def (thm RS @{thm meta_eq_to_obj_eq}) + | _ => thm) -(* embedding of standard natural number operations into integer operations *) - -local - val nat_embedding = map (pair ~1) @{lemma - "nat (int n) = n" - "i >= 0 --> int (nat i) = i" - "i < 0 --> int (nat i) = 0" - by simp_all} - - val nat_rewriting = @{lemma - "0 = nat 0" - "1 = nat 1" - "(number_of :: int => nat) = (%i. nat (number_of i))" - "int (nat 0) = 0" - "int (nat 1) = 1" - "op < = (%a b. int a < int b)" - "op <= = (%a b. int a <= int b)" - "Suc = (%a. nat (int a + 1))" - "op + = (%a b. nat (int a + int b))" - "op - = (%a b. nat (int a - int b))" - "op * = (%a b. nat (int a * int b))" - "op div = (%a b. nat (int a div int b))" - "op mod = (%a b. nat (int a mod int b))" - "min = (%a b. nat (min (int a) (int b)))" - "max = (%a b. nat (max (int a) (int b)))" - "int (nat (int a + int b)) = int a + int b" - "int (nat (int a + 1)) = int a + 1" (* special rule due to Suc above *) - "int (nat (int a * int b)) = int a * int b" - "int (nat (int a div int b)) = int a div int b" - "int (nat (int a mod int b)) = int a mod int b" - "int (nat (min (int a) (int b))) = min (int a) (int b)" - "int (nat (max (int a) (int b))) = max (int a) (int b)" - by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib - nat_mod_distrib int_mult[symmetric] zdiv_int[symmetric] - zmod_int[symmetric])} - - fun on_positive num f x = - (case try HOLogic.dest_number (Thm.term_of num) of - SOME (_, i) => if i >= 0 then SOME (f x) else NONE - | NONE => NONE) - - val cancel_int_nat_ss = HOL_ss - addsimps [@{thm Nat_Numeral.nat_number_of}] - addsimps [@{thm Nat_Numeral.int_nat_number_of}] - addsimps @{thms neg_simps} - - val int_eq = Thm.cterm_of @{theory} @{const "==" (int)} - - fun cancel_int_nat_simproc _ ss ct = - let - val num = Thm.dest_arg (Thm.dest_arg ct) - val goal = Thm.mk_binop int_eq ct num - val simpset = Simplifier.inherit_context ss cancel_int_nat_ss - fun tac _ = Simplifier.simp_tac simpset 1 - in on_positive num (Goal.prove_internal [] goal) tac end - - val nat_ss = HOL_ss - addsimps nat_rewriting - addsimprocs [ - Simplifier.make_simproc { - name = "cancel_int_nat_num", lhss = [@{cpat "int (nat _)"}], - proc = cancel_int_nat_simproc, identifier = [] }] - - fun conv ctxt = Simplifier.rewrite (Simplifier.context ctxt nat_ss) - - val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat})) - val uses_nat_int = Term.exists_subterm (member (op aconv) - [@{const of_nat (int)}, @{const nat}]) - - val nat_ops = [ - @{const less (nat)}, @{const less_eq (nat)}, - @{const Suc}, @{const plus (nat)}, @{const minus (nat)}, - @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}] - val nat_ops' = @{const of_nat (int)} :: @{const nat} :: nat_ops -in -fun nat_as_int ctxt = - map (apsnd ((uses_nat_type o Thm.prop_of) ?? Conv.fconv_rule (conv ctxt))) #> - exists (uses_nat_int o Thm.prop_of o snd) ?? append nat_embedding - -val setup_nat_as_int = - B.add_builtin_typ_ext (@{typ nat}, K true) #> - fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops' -end - - - -(* further normalizations: beta/eta, universal closure, atomize *) - -val eta_expand_eq = @{lemma "f == (%x. f x)" by (rule reflexive)} - -fun eta_expand_conv cv ctxt = - Conv.rewr_conv eta_expand_eq then_conv Conv.abs_conv (cv o snd) ctxt - -local - val eta_conv = eta_expand_conv - - fun args_conv cv ct = - (case Thm.term_of ct of - _ $ _ => Conv.combination_conv (args_conv cv) cv - | _ => Conv.all_conv) ct - - fun eta_args_conv cv 0 = args_conv o cv - | eta_args_conv cv i = eta_conv (eta_args_conv cv (i-1)) - - fun keep_conv ctxt = Conv.binder_conv (norm_conv o snd) ctxt - and eta_binder_conv ctxt = Conv.arg_conv (eta_conv norm_conv ctxt) - and keep_let_conv ctxt = Conv.combination_conv - (Conv.arg_conv (norm_conv ctxt)) (Conv.abs_conv (norm_conv o snd) ctxt) - and unfold_let_conv ctxt = Conv.combination_conv - (Conv.arg_conv (norm_conv ctxt)) (eta_conv norm_conv ctxt) - and unfold_conv thm ctxt = Conv.rewr_conv thm then_conv keep_conv ctxt - and unfold_ex1_conv ctxt = unfold_conv @{thm Ex1_def} ctxt - and unfold_ball_conv ctxt = unfold_conv (mk_meta_eq @{thm Ball_def}) ctxt - and unfold_bex_conv ctxt = unfold_conv (mk_meta_eq @{thm Bex_def}) ctxt - and norm_conv ctxt ct = - (case Thm.term_of ct of - Const (@{const_name All}, _) $ Abs _ => keep_conv - | Const (@{const_name All}, _) $ _ => eta_binder_conv - | Const (@{const_name All}, _) => eta_conv eta_binder_conv - | Const (@{const_name Ex}, _) $ Abs _ => keep_conv - | Const (@{const_name Ex}, _) $ _ => eta_binder_conv - | Const (@{const_name Ex}, _) => eta_conv eta_binder_conv - | Const (@{const_name Let}, _) $ _ $ Abs _ => keep_let_conv - | Const (@{const_name Let}, _) $ _ $ _ => unfold_let_conv - | Const (@{const_name Let}, _) $ _ => eta_conv unfold_let_conv - | Const (@{const_name Let}, _) => eta_conv (eta_conv unfold_let_conv) - | Const (@{const_name Ex1}, _) $ _ => unfold_ex1_conv - | Const (@{const_name Ex1}, _) => eta_conv unfold_ex1_conv - | Const (@{const_name Ball}, _) $ _ $ _ => unfold_ball_conv - | Const (@{const_name Ball}, _) $ _ => eta_conv unfold_ball_conv - | Const (@{const_name Ball}, _) => eta_conv (eta_conv unfold_ball_conv) - | Const (@{const_name Bex}, _) $ _ $ _ => unfold_bex_conv - | Const (@{const_name Bex}, _) $ _ => eta_conv unfold_bex_conv - | Const (@{const_name Bex}, _) => eta_conv (eta_conv unfold_bex_conv) - | Abs _ => Conv.abs_conv (norm_conv o snd) - | _ => - (case Term.strip_comb (Thm.term_of ct) of - (Const (c as (_, T)), ts) => - if SMT_Builtin.is_builtin_fun ctxt c ts - then eta_args_conv norm_conv - (length (Term.binder_types T) - length ts) - else args_conv o norm_conv - | _ => args_conv o norm_conv)) ctxt ct - - fun is_normed ctxt t = - (case t of - Const (@{const_name All}, _) $ Abs (_, _, u) => is_normed ctxt u - | Const (@{const_name All}, _) $ _ => false - | Const (@{const_name All}, _) => false - | Const (@{const_name Ex}, _) $ Abs (_, _, u) => is_normed ctxt u - | Const (@{const_name Ex}, _) $ _ => false - | Const (@{const_name Ex}, _) => false - | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) => - is_normed ctxt u1 andalso is_normed ctxt u2 - | Const (@{const_name Let}, _) $ _ $ _ => false - | Const (@{const_name Let}, _) $ _ => false - | Const (@{const_name Let}, _) => false - | Const (@{const_name Ex1}, _) $ _ => false - | Const (@{const_name Ex1}, _) => false - | Const (@{const_name Ball}, _) $ _ $ _ => false - | Const (@{const_name Ball}, _) $ _ => false - | Const (@{const_name Ball}, _) => false - | Const (@{const_name Bex}, _) $ _ $ _ => false - | Const (@{const_name Bex}, _) $ _ => false - | Const (@{const_name Bex}, _) => false - | Abs (_, _, u) => is_normed ctxt u - | _ => - (case Term.strip_comb t of - (Const (c as (_, T)), ts) => - if SMT_Builtin.is_builtin_fun ctxt c ts - then length (Term.binder_types T) = length ts andalso - forall (is_normed ctxt) ts - else forall (is_normed ctxt) ts - | (_, ts) => forall (is_normed ctxt) ts)) -in -fun norm_binder_conv ctxt = - U.if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt) - -val setup_unfolded_quants = - fold B.add_builtin_fun_ext'' [@{const_name Ball}, @{const_name Bex}, - @{const_name Ex1}] - -end - -fun norm_def ctxt thm = - (case Thm.prop_of thm of - @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ _ $ Abs _) => - norm_def ctxt (thm RS @{thm fun_cong}) - | Const (@{const_name "=="}, _) $ _ $ Abs _ => - norm_def ctxt (thm RS @{thm meta_eq_to_obj_eq}) - | _ => thm) +(** atomization **) fun atomize_conv ctxt ct = (case Thm.term_of ct of @@ -349,243 +73,543 @@ fold B.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="}, @{const_name all}, @{const_name Trueprop}] -fun normalize_rule ctxt = - Conv.fconv_rule ( - (* reduce lambda abstractions, except at known binders: *) - Thm.beta_conversion true then_conv - Thm.eta_conversion then_conv - norm_binder_conv ctxt) #> - norm_def ctxt #> - Drule.forall_intr_vars #> - Conv.fconv_rule (atomize_conv ctxt) - - -(* lift lambda terms into additional rules *) +(** unfold special quantifiers **) local - fun used_vars cvs ct = - let - val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs) - val add = (fn SOME ct => insert (op aconvc) ct | _ => I) - in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end - - fun apply cv thm = - let val thm' = Thm.combination thm (Thm.reflexive cv) - in Thm.transitive thm' (Thm.beta_conversion false (Thm.rhs_of thm')) end - fun apply_def cvs eq = Thm.symmetric (fold apply cvs eq) + val ex1_def = mk_meta_eq @{lemma + "Ex1 = (%P. EX x. P x & (ALL y. P y --> y = x))" + by (rule ext) (simp only: Ex1_def)} - fun replace_lambda cvs ct (cx as (ctxt, defs)) = - let - val cvs' = used_vars cvs ct - val ct' = fold_rev Thm.cabs cvs' ct - in - (case Termtab.lookup defs (Thm.term_of ct') of - SOME eq => (apply_def cvs' eq, cx) - | NONE => - let - val {T, ...} = Thm.rep_cterm ct' and n = Name.uu - val (n', ctxt') = yield_singleton Variable.variant_fixes n ctxt - val cu = U.mk_cequals (U.certify ctxt (Free (n', T))) ct' - val (eq, ctxt'') = yield_singleton Assumption.add_assumes cu ctxt' - val defs' = Termtab.update (Thm.term_of ct', eq) defs - in (apply_def cvs' eq, (ctxt'', defs')) end) - end + val ball_def = mk_meta_eq @{lemma "Ball = (%A P. ALL x. x : A --> P x)" + by (rule ext)+ (rule Ball_def)} + + val bex_def = mk_meta_eq @{lemma "Bex = (%A P. EX x. x : A & P x)" + by (rule ext)+ (rule Bex_def)} - fun none ct cx = (Thm.reflexive ct, cx) - fun in_comb f g ct cx = - let val (cu1, cu2) = Thm.dest_comb ct - in cx |> f cu1 ||>> g cu2 |>> uncurry Thm.combination end - fun in_arg f = in_comb none f - fun in_abs f cvs ct (ctxt, defs) = - let - val (n, ctxt') = yield_singleton Variable.variant_fixes Name.uu ctxt - val (cv, cu) = Thm.dest_abs (SOME n) ct - in (ctxt', defs) |> f (cv :: cvs) cu |>> Thm.abstract_rule n cv end - - fun traverse cvs ct = - (case Thm.term_of ct of - Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs traverse cvs) - | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs traverse cvs) - | Const (@{const_name Let}, _) $ _ $ Abs _ => - in_comb (in_arg (traverse cvs)) (in_abs traverse cvs) - | Abs _ => at_lambda cvs - | _ $ _ => in_comb (traverse cvs) (traverse cvs) - | _ => none) ct + val special_quants = [(@{const_name Ex1}, ex1_def), + (@{const_name Ball}, ball_def), (@{const_name Bex}, bex_def)] + + fun special_quant (Const (n, _)) = AList.lookup (op =) special_quants n + | special_quant _ = NONE - and at_lambda cvs ct = - in_abs traverse cvs ct #-> (fn thm => - replace_lambda cvs (Thm.rhs_of thm) #>> Thm.transitive thm) + fun special_quant_conv _ ct = + (case special_quant (Thm.term_of ct) of + SOME thm => Conv.rewr_conv thm + | NONE => Conv.all_conv) ct +in - fun has_free_lambdas t = - (case t of - Const (@{const_name All}, _) $ Abs (_, _, u) => has_free_lambdas u - | Const (@{const_name Ex}, _) $ Abs (_, _, u) => has_free_lambdas u - | Const (@{const_name Let}, _) $ u1 $ Abs (_, _, u2) => - has_free_lambdas u1 orelse has_free_lambdas u2 - | Abs _ => true - | u1 $ u2 => has_free_lambdas u1 orelse has_free_lambdas u2 - | _ => false) +fun unfold_special_quants_conv ctxt = + U.if_exists_conv (is_some o special_quant) + (Conv.top_conv special_quant_conv ctxt) - fun lift_lm f thm cx = - if not (has_free_lambdas (Thm.prop_of thm)) then (thm, cx) - else cx |> f (Thm.cprop_of thm) |>> (fn thm' => Thm.equal_elim thm' thm) -in -fun lift_lambdas irules ctxt = - let - val cx = (ctxt, Termtab.empty) - val (idxs, thms) = split_list irules - val (thms', (ctxt', defs)) = fold_map (lift_lm (traverse [])) thms cx - val eqs = Termtab.fold (cons o normalize_rule ctxt' o snd) defs [] - in (map (pair ~1) eqs @ (idxs ~~ thms'), ctxt') end +val setup_unfolded_quants = fold (B.add_builtin_fun_ext'' o fst) special_quants + end - -(* make application explicit for functions with varying number of arguments *) +(** trigger inference **) local - val const = prefix "c" and free = prefix "f" - fun min i (e as (_, j)) = if i <> j then (true, Int.min (i, j)) else e - fun add t i = Symtab.map_default (t, (false, i)) (min i) - fun traverse t = + (*** check trigger syntax ***) + + fun dest_trigger (Const (@{const_name pat}, _) $ _) = SOME true + | dest_trigger (Const (@{const_name nopat}, _) $ _) = SOME false + | dest_trigger _ = NONE + + fun eq_list [] = false + | eq_list (b :: bs) = forall (equal b) bs + + fun proper_trigger t = + t + |> these o try HOLogic.dest_list + |> map (map_filter dest_trigger o these o try HOLogic.dest_list) + |> (fn [] => false | bss => forall eq_list bss) + + fun proper_quant inside f t = + (case t of + Const (@{const_name All}, _) $ Abs (_, _, u) => proper_quant true f u + | Const (@{const_name Ex}, _) $ Abs (_, _, u) => proper_quant true f u + | @{const trigger} $ p $ u => + (if inside then f p else false) andalso proper_quant false f u + | Abs (_, _, u) => proper_quant false f u + | u1 $ u2 => proper_quant false f u1 andalso proper_quant false f u2 + | _ => true) + + fun check_trigger_error ctxt t = + error ("SMT triggers must only occur under quantifier and multipatterns " ^ + "must have the same kind: " ^ Syntax.string_of_term ctxt t) + + fun check_trigger_conv ctxt ct = + if proper_quant false proper_trigger (Thm.term_of ct) then Conv.all_conv ct + else check_trigger_error ctxt (Thm.term_of ct) + + + (*** infer simple triggers ***) + + fun dest_cond_eq ct = + (case Thm.term_of ct of + Const (@{const_name HOL.eq}, _) $ _ $ _ => Thm.dest_binop ct + | @{const HOL.implies} $ _ $ _ => dest_cond_eq (Thm.dest_arg ct) + | _ => raise CTERM ("no equation", [ct])) + + fun get_constrs thy (Type (n, _)) = these (Datatype.get_constrs thy n) + | get_constrs _ _ = [] + + fun is_constr thy (n, T) = + let fun match (m, U) = m = n andalso Sign.typ_instance thy (T, U) + in can (the o find_first match o get_constrs thy o Term.body_type) T end + + fun is_constr_pat thy t = + (case Term.strip_comb t of + (Free _, []) => true + | (Const c, ts) => is_constr thy c andalso forall (is_constr_pat thy) ts + | _ => false) + + fun is_simp_lhs ctxt t = (case Term.strip_comb t of - (Const (n, _), ts) => add (const n) (length ts) #> fold traverse ts - | (Free (n, _), ts) => add (free n) (length ts) #> fold traverse ts - | (Abs (_, _, u), ts) => fold traverse (u :: ts) - | (_, ts) => fold traverse ts) - fun prune tab = Symtab.fold (fn (n, (true, i)) => - Symtab.update (n, i) | _ => I) tab Symtab.empty + (Const c, ts as _ :: _) => + not (B.is_builtin_fun_ext ctxt c ts) andalso + forall (is_constr_pat (ProofContext.theory_of ctxt)) ts + | _ => false) + + fun has_all_vars vs t = + subset (op aconv) (vs, map Free (Term.add_frees t [])) + + fun minimal_pats vs ct = + if has_all_vars vs (Thm.term_of ct) then + (case Thm.term_of ct of + _ $ _ => + (case pairself (minimal_pats vs) (Thm.dest_comb ct) of + ([], []) => [[ct]] + | (ctss, ctss') => union (eq_set (op aconvc)) ctss ctss') + | _ => [[ct]]) + else [] + + fun proper_mpat _ _ _ [] = false + | proper_mpat thy gen u cts = + let + val tps = (op ~~) (`gen (map Thm.term_of cts)) + fun some_match u = tps |> exists (fn (t', t) => + Pattern.matches thy (t', u) andalso not (t aconv u)) + in not (Term.exists_subterm some_match u) end + + val pat = U.mk_const_pat @{theory} @{const_name SMT.pat} U.destT1 + fun mk_pat ct = Thm.capply (U.instT' ct pat) ct - fun binop_conv cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2 - fun nary_conv conv1 conv2 ct = - (Conv.combination_conv (nary_conv conv1 conv2) conv2 else_conv conv1) ct - fun abs_conv conv tb = Conv.abs_conv (fn (cv, cx) => - let val n = fst (Term.dest_Free (Thm.term_of cv)) - in conv (Symtab.update (free n, 0) tb) cx end) - val fun_app_rule = @{lemma "f x == fun_app f x" by (simp add: fun_app_def)} + fun mk_clist T = pairself (Thm.cterm_of @{theory}) + (HOLogic.cons_const T, HOLogic.nil_const T) + fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil + val mk_pat_list = mk_list (mk_clist @{typ SMT.pattern}) + val mk_mpat_list = mk_list (mk_clist @{typ "SMT.pattern list"}) + fun mk_trigger ctss = mk_mpat_list (mk_pat_list mk_pat) ctss + + val trigger_eq = + mk_meta_eq @{lemma "p = SMT.trigger t p" by (simp add: trigger_def)} + + fun insert_trigger_conv [] ct = Conv.all_conv ct + | insert_trigger_conv ctss ct = + let val (ctr, cp) = Thm.dest_binop (Thm.rhs_of trigger_eq) ||> rpair ct + in Thm.instantiate ([], [cp, (ctr, mk_trigger ctss)]) trigger_eq end + + fun infer_trigger_eq_conv outer_ctxt (ctxt, cvs) ct = + let + val (lhs, rhs) = dest_cond_eq ct + + val vs = map Thm.term_of cvs + val thy = ProofContext.theory_of ctxt + + fun get_mpats ct = + if is_simp_lhs ctxt (Thm.term_of ct) then minimal_pats vs ct + else [] + val gen = Variable.export_terms ctxt outer_ctxt + val filter_mpats = filter (proper_mpat thy gen (Thm.term_of rhs)) + + in insert_trigger_conv (filter_mpats (get_mpats lhs)) ct end + + fun try_trigger_conv cv ct = + if proper_quant false (K false) (Thm.term_of ct) then Conv.all_conv ct + else Conv.try_conv cv ct + + fun infer_trigger_conv ctxt = + if Config.get ctxt SMT_Config.infer_triggers then + try_trigger_conv (U.under_quant_conv (infer_trigger_eq_conv ctxt) ctxt) + else Conv.all_conv in -fun explicit_application ctxt irules = - let - fun sub_conv tb ctxt ct = - (case Term.strip_comb (Thm.term_of ct) of - (Const (n, _), ts) => app_conv tb (const n) (length ts) ctxt - | (Free (n, _), ts) => app_conv tb (free n) (length ts) ctxt - | (Abs _, _) => nary_conv (abs_conv sub_conv tb ctxt) (sub_conv tb ctxt) - | (_, _) => nary_conv Conv.all_conv (sub_conv tb ctxt)) ct - and app_conv tb n i ctxt = - (case Symtab.lookup tb n of - NONE => nary_conv Conv.all_conv (sub_conv tb ctxt) - | SOME j => fun_app_conv tb ctxt (i - j)) - and fun_app_conv tb ctxt i ct = ( - if i = 0 then nary_conv Conv.all_conv (sub_conv tb ctxt) - else - Conv.rewr_conv fun_app_rule then_conv - binop_conv (fun_app_conv tb ctxt (i-1)) (sub_conv tb ctxt)) ct + +fun trigger_conv ctxt = + U.prop_conv (check_trigger_conv ctxt then_conv infer_trigger_conv ctxt) - fun needs_exp_app tab = Term.exists_subterm (fn - Bound _ $ _ => true - | Const (n, _) => Symtab.defined tab (const n) - | Free (n, _) => Symtab.defined tab (free n) - | _ => false) +val setup_trigger = fold B.add_builtin_fun_ext'' + [@{const_name SMT.pat}, @{const_name SMT.nopat}, @{const_name SMT.trigger}] - fun rewrite tab ctxt thm = - if not (needs_exp_app tab (Thm.prop_of thm)) then thm - else Conv.fconv_rule (sub_conv tab ctxt) thm - - val tab = prune (fold (traverse o Thm.prop_of o snd) irules Symtab.empty) - in map (apsnd (rewrite tab ctxt)) irules end end - -(* add missing datatype selectors via hypothetical definitions *) +(** adding quantifier weights **) local - val add = (fn Type (n, _) => Symtab.update (n, ()) | _ => I) + (*** check weight syntax ***) + + val has_no_weight = + not o Term.exists_subterm (fn @{const SMT.weight} => true | _ => false) - fun collect t = - (case Term.strip_comb t of - (Abs (_, T, t), _) => add T #> collect t - | (Const (_, T), ts) => collects T ts - | (Free (_, T), ts) => collects T ts - | _ => I) - and collects T ts = - let val ((Ts, Us), U) = Term.strip_type T |> apfst (chop (length ts)) - in fold add Ts #> add (Us ---> U) #> fold collect ts end + fun is_weight (@{const SMT.weight} $ w $ t) = + (case try HOLogic.dest_number w of + SOME (_, i) => i > 0 andalso has_no_weight t + | _ => false) + | is_weight t = has_no_weight t + + fun proper_trigger (@{const SMT.trigger} $ _ $ t) = is_weight t + | proper_trigger t = has_no_weight t + + fun check_weight_error ctxt t = + error ("SMT weight must be a positive number and must only occur " ^ + "under the top-most quantifier and an optional trigger: " ^ + Syntax.string_of_term ctxt t) - fun add_constructors thy n = - (case Datatype.get_info thy n of - NONE => I - | SOME {descr, ...} => fold (fn (_, (_, _, cs)) => fold (fn (n, ds) => - fold (insert (op =) o pair n) (1 upto length ds)) cs) descr) + fun check_weight_conv ctxt ct = + if U.under_quant proper_trigger (Thm.term_of ct) then Conv.all_conv ct + else check_weight_error ctxt (Thm.term_of ct) + + + (*** insertion of weights ***) + + fun under_trigger_conv cv ct = + (case Thm.term_of ct of + @{const SMT.trigger} $ _ $ _ => Conv.arg_conv cv + | _ => cv) ct - fun add_selector (c as (n, i)) ctxt = - (case Datatype_Selectors.lookup_selector ctxt c of - SOME _ => ctxt - | NONE => - let - val T = Sign.the_const_type (ProofContext.theory_of ctxt) n - val U = Term.body_type T --> nth (Term.binder_types T) (i-1) - in - ctxt - |> yield_singleton Variable.variant_fixes Name.uu - |>> pair ((n, T), i) o rpair U - |-> Context.proof_map o Datatype_Selectors.add_selector - end) + val weight_eq = + mk_meta_eq @{lemma "p = SMT.weight i p" by (simp add: weight_def)} + fun mk_weight_eq w = + let val cv = Thm.dest_arg1 (Thm.rhs_of weight_eq) + in + Thm.instantiate ([], [(cv, Numeral.mk_cnumber @{ctyp int} w)]) weight_eq + end + + fun add_weight_conv NONE _ = Conv.all_conv + | add_weight_conv (SOME weight) ctxt = + let val cv = Conv.rewr_conv (mk_weight_eq weight) + in U.under_quant_conv (K (under_trigger_conv cv)) ctxt end in -fun datatype_selectors irules ctxt = - let - val ns = Symtab.keys (fold (collect o Thm.prop_of o snd) irules Symtab.empty) - val cs = fold (add_constructors (ProofContext.theory_of ctxt)) ns [] - in (irules, fold add_selector cs ctxt) end - (* FIXME: also generate hypothetical definitions for the selectors *) +fun weight_conv weight ctxt = + U.prop_conv (check_weight_conv ctxt then_conv add_weight_conv weight ctxt) + +val setup_weight = B.add_builtin_fun_ext'' @{const_name SMT.weight} end - -(* combined normalization *) +(** combined general normalizations **) -type extra_norm = bool -> (int * thm) list -> Proof.context -> - (int * thm) list * Proof.context - -fun with_context f irules ctxt = (f ctxt irules, ctxt) +fun gen_normalize1_conv ctxt weight = + atomize_conv ctxt then_conv + unfold_special_quants_conv ctxt then_conv + trigger_conv ctxt then_conv + weight_conv weight ctxt -fun normalize extra_norm with_datatypes irules ctxt = - let - fun norm f ctxt' (i, thm) = - if Config.get ctxt' SMT_Config.drop_bad_facts then - (case try (f ctxt') thm of - SOME thm' => SOME (i, thm') - | NONE => (SMT_Config.verbose_msg ctxt' (prefix ("Warning: " ^ - "dropping assumption: ") o Display.string_of_thm ctxt') thm; NONE)) - else SOME (i, f ctxt' thm) - in - irules - |> map (apsnd instantiate_elim) - |> trivial_distinct ctxt - |> rewrite_bool_cases ctxt - |> normalize_numerals ctxt - |> nat_as_int ctxt - |> rpair ctxt - |-> extra_norm with_datatypes - |-> with_context (map_filter o norm normalize_rule) - |-> SMT_Monomorph.monomorph - |-> lift_lambdas - |-> with_context explicit_application - |-> (if with_datatypes then datatype_selectors else pair) - end +fun gen_normalize1 ctxt weight thm = + thm + |> instantiate_elim + |> norm_def + |> Conv.fconv_rule (Thm.beta_conversion true then_conv Thm.eta_conversion) + |> Drule.forall_intr_vars + |> Conv.fconv_rule (gen_normalize1_conv ctxt weight) + +fun drop_fact_warning ctxt = + let val pre = prefix "Warning: dropping assumption: " + in SMT_Config.verbose_msg ctxt (pre o Display.string_of_thm ctxt) end + +fun gen_norm1_safe ctxt (i, (weight, thm)) = + if Config.get ctxt SMT_Config.drop_bad_facts then + (case try (gen_normalize1 ctxt weight) thm of + SOME thm' => SOME (i, thm') + | NONE => (drop_fact_warning ctxt thm; NONE)) + else SOME (i, gen_normalize1 ctxt weight thm) + +fun gen_normalize ctxt iwthms = map_filter (gen_norm1_safe ctxt) iwthms -(* setup *) +(* unfolding of definitions and theory-specific rewritings *) + +(** unfold trivial distincts **) + +local + fun is_trivial_distinct (Const (@{const_name distinct}, _) $ t) = + (case try HOLogic.dest_list t of + SOME [] => true + | SOME [_] => true + | _ => false) + | is_trivial_distinct _ = false + + val thms = map mk_meta_eq @{lemma + "distinct [] = True" + "distinct [x] = True" + "distinct [x, y] = (x ~= y)" + by simp_all} + fun distinct_conv _ = + U.if_true_conv is_trivial_distinct (Conv.rewrs_conv thms) +in + +fun trivial_distinct_conv ctxt = U.if_exists_conv is_trivial_distinct + (Conv.top_conv distinct_conv ctxt) + +end + + +(** rewrite bool case expressions as if expressions **) + +local + fun is_bool_case (Const (@{const_name "bool.bool_case"}, _)) = true + | is_bool_case _ = false + + val thm = mk_meta_eq @{lemma + "bool_case = (%x y P. if P then x else y)" by (rule ext)+ simp} + + fun unfold_conv _ = U.if_true_conv is_bool_case (Conv.rewr_conv thm) +in + +fun rewrite_bool_case_conv ctxt = U.if_exists_conv is_bool_case + (Conv.top_conv unfold_conv ctxt) + +val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"} + +end + + +(** unfold abs, min and max **) + +local + val abs_def = mk_meta_eq @{lemma + "abs = (%a::'a::abs_if. if a < 0 then - a else a)" + by (rule ext) (rule abs_if)} + + val min_def = mk_meta_eq @{lemma "min = (%a b. if a <= b then a else b)" + by (rule ext)+ (rule min_def)} + + val max_def = mk_meta_eq @{lemma "max = (%a b. if a <= b then b else a)" + by (rule ext)+ (rule max_def)} + + val defs = [(@{const_name min}, min_def), (@{const_name max}, max_def), + (@{const_name abs}, abs_def)] + + fun is_builtinT ctxt T = B.is_builtin_typ_ext ctxt (Term.domain_type T) + + fun abs_min_max ctxt (Const (n, T)) = + (case AList.lookup (op =) defs n of + NONE => NONE + | SOME thm => if is_builtinT ctxt T then SOME thm else NONE) + | abs_min_max _ _ = NONE + + fun unfold_amm_conv ctxt ct = + (case abs_min_max ctxt (Thm.term_of ct) of + SOME thm => Conv.rewr_conv thm + | NONE => Conv.all_conv) ct +in + +fun unfold_abs_min_max_conv ctxt = + U.if_exists_conv (is_some o abs_min_max ctxt) + (Conv.top_conv unfold_amm_conv ctxt) + +val setup_abs_min_max = fold (B.add_builtin_fun_ext'' o fst) defs + +end + + +(** embedding of standard natural number operations into integer operations **) + +local + val nat_embedding = @{lemma + "ALL n. nat (int n) = n" + "ALL i. i >= 0 --> int (nat i) = i" + "ALL i. i < 0 --> int (nat i) = 0" + by simp_all} + + val nat_ops = [ + @{const less (nat)}, @{const less_eq (nat)}, + @{const Suc}, @{const plus (nat)}, @{const minus (nat)}, + @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}] + + val nat_consts = nat_ops @ [@{const number_of (nat)}, + @{const zero_class.zero (nat)}, @{const one_class.one (nat)}] + + val nat_int_coercions = [@{const of_nat (int)}, @{const nat}] + + val nat_ops' = nat_int_coercions @ nat_ops + + val is_nat_const = member (op aconv) nat_consts + + val expands = map mk_meta_eq @{lemma + "0 = nat 0" + "1 = nat 1" + "(number_of :: int => nat) = (%i. nat (number_of i))" + "op < = (%a b. int a < int b)" + "op <= = (%a b. int a <= int b)" + "Suc = (%a. nat (int a + 1))" + "op + = (%a b. nat (int a + int b))" + "op - = (%a b. nat (int a - int b))" + "op * = (%a b. nat (int a * int b))" + "op div = (%a b. nat (int a div int b))" + "op mod = (%a b. nat (int a mod int b))" + by (auto intro!: ext simp add: nat_mult_distrib nat_div_distrib + nat_mod_distrib)} + + val ints = map mk_meta_eq @{lemma + "int 0 = 0" + "int 1 = 1" + "int (Suc n) = int n + 1" + "int (n + m) = int n + int m" + "int (n - m) = int (nat (int n - int m))" + "int (n * m) = int n * int m" + "int (n div m) = int n div int m" + "int (n mod m) = int n mod int m" + "int (if P then n else m) = (if P then int n else int m)" + by (auto simp add: int_mult zdiv_int zmod_int)} + + fun mk_number_eq ctxt i lhs = + let + val eq = U.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i) + val ss = HOL_ss + addsimps [@{thm Nat_Numeral.int_nat_number_of}] + addsimps @{thms neg_simps} + fun tac _ = Simplifier.simp_tac (Simplifier.context ctxt ss) 1 + in Goal.norm_result (Goal.prove_internal [] eq tac) end + + fun expand_head_conv cv ct = + (case Thm.term_of ct of + _ $ _ => + Conv.fun_conv (expand_head_conv cv) then_conv + Thm.beta_conversion false + | _ => cv) ct + + fun int_conv ctxt ct = + (case Thm.term_of ct of + @{const of_nat (int)} $ (n as (@{const number_of (nat)} $ _)) => + Conv.rewr_conv (mk_number_eq ctxt (snd (HOLogic.dest_number n)) ct) + | @{const of_nat (int)} $ _ => + (Conv.rewrs_conv ints then_conv Conv.sub_conv ints_conv ctxt) else_conv + Conv.top_sweep_conv nat_conv ctxt + | _ => Conv.no_conv) ct + + and ints_conv ctxt = Conv.top_sweep_conv int_conv ctxt + + and expand_conv ctxt = + U.if_conv (not o is_nat_const o Term.head_of) Conv.no_conv + (expand_head_conv (Conv.rewrs_conv expands) then_conv ints_conv ctxt) + + and nat_conv ctxt = U.if_exists_conv is_nat_const + (Conv.top_sweep_conv expand_conv ctxt) + + val uses_nat_int = Term.exists_subterm (member (op aconv) nat_int_coercions) +in + +val nat_as_int_conv = nat_conv + +fun add_nat_embedding thms = + if exists (uses_nat_int o Thm.prop_of) thms then (thms, nat_embedding) + else (thms, []) + +val setup_nat_as_int = + B.add_builtin_typ_ext (@{typ nat}, K true) #> + fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops' + +end + + +(** normalize numerals **) + +local + (* + rewrite negative numerals into positive numerals, + rewrite Numeral0 into 0 + rewrite Numeral1 into 1 + *) + + fun is_strange_number ctxt (t as Const (@{const_name number_of}, _) $ _) = + (case try HOLogic.dest_number t of + SOME (_, i) => B.is_builtin_num ctxt t andalso i < 2 + | NONE => false) + | is_strange_number _ _ = false + + val pos_num_ss = HOL_ss + addsimps [@{thm Int.number_of_minus}, @{thm Int.number_of_Min}] + addsimps [@{thm Int.number_of_Pls}, @{thm Int.numeral_1_eq_1}] + addsimps @{thms Int.pred_bin_simps} + addsimps @{thms Int.normalize_bin_simps} + addsimps @{lemma + "Int.Min = - Int.Bit1 Int.Pls" + "Int.Bit0 (- Int.Pls) = - Int.Pls" + "Int.Bit0 (- k) = - Int.Bit0 k" + "Int.Bit1 (- k) = - Int.Bit1 (Int.pred k)" + by simp_all (simp add: pred_def)} + + fun norm_num_conv ctxt = U.if_conv (is_strange_number ctxt) + (Simplifier.rewrite (Simplifier.context ctxt pos_num_ss)) Conv.no_conv +in + +fun normalize_numerals_conv ctxt = U.if_exists_conv (is_strange_number ctxt) + (Conv.top_sweep_conv norm_num_conv ctxt) + +end + + +(** combined unfoldings and rewritings **) + +fun unfold_conv ctxt = + trivial_distinct_conv ctxt then_conv + rewrite_bool_case_conv ctxt then_conv + unfold_abs_min_max_conv ctxt then_conv + nat_as_int_conv ctxt then_conv + normalize_numerals_conv ctxt then_conv + Thm.beta_conversion true + +fun burrow_ids f ithms = + let + val (is, thms) = split_list ithms + val (thms', extra_thms) = f thms + in (is ~~ thms') @ map (pair ~1) extra_thms end + +fun unfold ctxt = + burrow_ids (map (Conv.fconv_rule (unfold_conv ctxt)) #> add_nat_embedding) + + + +(* overall normalization *) + +type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list + +structure Extra_Norms = Generic_Data +( + type T = extra_norm U.dict + val empty = [] + val extend = I + val merge = U.dict_merge fst +) + +fun add_extra_norm (cs, norm) = Extra_Norms.map (U.dict_update (cs, norm)) + +fun apply_extra_norms ctxt = + let + val cs = SMT_Config.solver_class_of ctxt + val es = U.dict_lookup (Extra_Norms.get (Context.Proof ctxt)) cs + in burrow_ids (fold (fn e => e ctxt) es o rpair []) end + +fun normalize ctxt iwthms = + iwthms + |> gen_normalize ctxt + |> unfold ctxt + |> apply_extra_norms ctxt val setup = Context.theory_map ( + setup_atomize #> + setup_unfolded_quants #> + setup_trigger #> + setup_weight #> setup_bool_case #> - setup_nat_as_int #> - setup_unfolded_quants #> - setup_atomize) + setup_abs_min_max #> + setup_nat_as_int) end diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/Tools/SMT/smt_solver.ML --- a/src/HOL/Tools/SMT/smt_solver.ML Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/Tools/SMT/smt_solver.ML Wed Dec 15 08:39:24 2010 +0100 @@ -9,7 +9,6 @@ (*configuration*) type interface = { class: SMT_Utils.class, - extra_norm: SMT_Normalize.extra_norm, translate: SMT_Translate.config } datatype outcome = Unsat | Sat | Unknown type solver_config = { @@ -26,8 +25,8 @@ default_max_relevant: int } (*registry*) - type solver = bool option -> Proof.context -> (int * thm) list -> - int list * thm + type solver = bool option -> Proof.context -> + (int * (int option * thm)) list -> int list * thm val add_solver: solver_config -> theory -> theory val solver_name_of: Proof.context -> string val solver_of: Proof.context -> solver @@ -37,7 +36,8 @@ val default_max_relevant: Proof.context -> string -> int (*filter*) - val smt_filter: bool -> Time.time -> Proof.state -> ('a * thm) list -> int -> + val smt_filter: bool -> Time.time -> Proof.state -> + ('a * (int option * thm)) list -> int -> {outcome: SMT_Failure.failure option, used_facts: ('a * thm) list, run_time_in_msecs: int option} @@ -59,7 +59,6 @@ type interface = { class: SMT_Utils.class, - extra_norm: SMT_Normalize.extra_norm, translate: SMT_Translate.config } datatype outcome = Unsat | Sat | Unknown @@ -176,7 +175,7 @@ Pretty.big_list "functions:" (map p_term (Symtab.dest terms))])) () end -fun invoke translate_config name cmd options irules ctxt = +fun invoke translate_config name cmd options ithms ctxt = let val args = C.solver_options_of ctxt @ options ctxt val comments = ("solver: " ^ name) :: @@ -184,7 +183,7 @@ ("random seed: " ^ string_of_int (Config.get ctxt C.random_seed)) :: "arguments:" :: args in - irules + ithms |> tap (trace_assms ctxt) |> SMT_Translate.translate translate_config ctxt comments ||> tap (trace_recon_data ctxt) @@ -197,22 +196,23 @@ else discharge_definitions (@{thm reflexive} RS thm) fun set_has_datatypes with_datatypes translate = - let - val {prefixes, header, is_fol, has_datatypes, serialize} = translate - val with_datatypes' = has_datatypes andalso with_datatypes - val translate' = {prefixes=prefixes, header=header, is_fol=is_fol, - has_datatypes=with_datatypes', serialize=serialize} - in (with_datatypes', translate') end + let val {prefixes, header, is_fol, has_datatypes, serialize} = translate + in + {prefixes=prefixes, header=header, is_fol=is_fol, + has_datatypes=has_datatypes andalso with_datatypes, serialize=serialize} + end -fun trace_assumptions ctxt irules idxs = +fun trace_assumptions ctxt iwthms idxs = let - val thms = filter (fn i => i >= 0) idxs - |> map_filter (AList.lookup (op =) irules) + val wthms = + idxs + |> filter (fn i => i >= 0) + |> map_filter (AList.lookup (op =) iwthms) in - if Config.get ctxt C.trace_used_facts andalso length thms > 0 + if Config.get ctxt C.trace_used_facts andalso length wthms > 0 then tracing (Pretty.string_of (Pretty.big_list "SMT used facts:" - (map (Display.pretty_thm ctxt) thms))) + (map (Display.pretty_thm ctxt o snd) wthms))) else () end @@ -220,7 +220,8 @@ (* registry *) -type solver = bool option -> Proof.context -> (int * thm) list -> int list * thm +type solver = bool option -> Proof.context -> (int * (int option * thm)) list -> + int list * thm type solver_info = { env_var: string, @@ -231,22 +232,22 @@ (int list * thm) * Proof.context, default_max_relevant: int } -fun gen_solver name (info : solver_info) rm ctxt irules = +fun gen_solver name (info : solver_info) rm ctxt iwthms = let val {env_var, is_remote, options, interface, reconstruct, ...} = info - val {extra_norm, translate, ...} = interface - val (with_datatypes, translate') = - set_has_datatypes (Config.get ctxt C.datatypes) translate + val {translate, ...} = interface + val translate' = set_has_datatypes (Config.get ctxt C.datatypes) translate val cmd = (rm, env_var, is_remote, name) in - (irules, ctxt) - |-> SMT_Normalize.normalize extra_norm with_datatypes + SMT_Normalize.normalize ctxt iwthms + |> rpair ctxt + |-> SMT_Monomorph.monomorph |-> invoke translate' name cmd options |-> reconstruct |-> (fn (idxs, thm) => fn ctxt' => thm |> singleton (ProofContext.export ctxt' ctxt) |> discharge_definitions - |> tap (fn _ => trace_assumptions ctxt irules idxs) + |> tap (fn _ => trace_assumptions ctxt iwthms idxs) |> pair idxs) end @@ -330,38 +331,45 @@ | TVar (_, []) => true | _ => false)) -fun smt_solver rm ctxt irules = +fun smt_solver rm ctxt iwthms = (* without this test, we would run into problems when atomizing the rules: *) - if exists (has_topsort o Thm.prop_of o snd) irules then + if exists (has_topsort o Thm.prop_of o snd o snd) iwthms then raise SMT_Failure.SMT (SMT_Failure.Other_Failure ("proof state " ^ "contains the universal sort {}")) - else solver_of ctxt rm ctxt irules + else solver_of ctxt rm ctxt iwthms val cnot = Thm.cterm_of @{theory} @{const Not} -fun smt_filter run_remote time_limit st xrules i = +fun mk_result outcome xrules = + { outcome = outcome, used_facts = xrules, run_time_in_msecs = NONE } + +fun smt_filter run_remote time_limit st xwrules i = let - val {facts, goal, ...} = Proof.goal st val ctxt = Proof.context_of st |> Config.put C.oracle false |> Config.put C.timeout (Time.toReal time_limit) |> Config.put C.drop_bad_facts true |> Config.put C.filter_only_facts true + + val {facts, goal, ...} = Proof.goal st val ({context=ctxt', prems, concl, ...}, _) = Subgoal.focus ctxt i goal fun negate ct = Thm.dest_comb ct ||> Thm.capply cnot |-> Thm.capply val cprop = negate (Thm.rhs_of (SMT_Normalize.atomize_conv ctxt' concl)) - val irs = map (pair ~1) (Thm.assume cprop :: prems @ facts) - val rm = SOME run_remote + + val (xs, wthms) = split_list xwrules + val xrules = xs ~~ map snd wthms in - (xrules, map snd xrules) - ||> distinct (op =) o fst o smt_solver rm ctxt' o append irs o map_index I - |-> map_filter o try o nth - |> (fn xs => {outcome=NONE, used_facts=if solver_name_of ctxt = "z3" (* FIXME *) then xs - else xrules, run_time_in_msecs=NONE}) + wthms + |> map_index I + |> append (map (pair ~1 o pair NONE) (Thm.assume cprop :: prems @ facts)) + |> smt_solver (SOME run_remote) ctxt' + |> distinct (op =) o fst + |> map_filter (try (nth xrules)) + |> (if solver_name_of ctxt = "z3" (* FIXME *) then I else K xrules) + |> mk_result NONE end - handle SMT_Failure.SMT fail => {outcome=SOME fail, used_facts=[], - run_time_in_msecs=NONE} + handle SMT_Failure.SMT fail => mk_result (SOME fail) [] (* FIXME: measure runtime *) @@ -373,18 +381,18 @@ THEN' Tactic.rtac @{thm ccontr} THEN' SUBPROOF (fn {context=ctxt', prems, ...} => let - fun solve irules = snd (smt_solver NONE ctxt' irules) + fun solve iwthms = snd (smt_solver NONE ctxt' iwthms) val tag = "Solver " ^ C.solver_of ctxt' ^ ": " val str_of = prefix tag o SMT_Failure.string_of_failure ctxt' - fun safe_solve irules = - if pass_exns then SOME (solve irules) - else (SOME (solve irules) + fun safe_solve iwthms = + if pass_exns then SOME (solve iwthms) + else (SOME (solve iwthms) handle SMT_Failure.SMT (fail as SMT_Failure.Counterexample _) => (C.verbose_msg ctxt' str_of fail; NONE) | SMT_Failure.SMT fail => (C.trace_msg ctxt' str_of fail; NONE)) in - safe_solve (map (pair ~1) (rules @ prems)) + safe_solve (map (pair ~1 o pair NONE) (rules @ prems)) |> (fn SOME thm => Tactic.rtac thm 1 | _ => Tactical.no_tac) end) ctxt diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/Tools/SMT/smt_utils.ML --- a/src/HOL/Tools/SMT/smt_utils.ML Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/Tools/SMT/smt_utils.ML Wed Dec 15 08:39:24 2010 +0100 @@ -25,6 +25,7 @@ (*terms*) val dest_conj: term -> term * term val dest_disj: term -> term * term + val under_quant: (term -> 'a) -> term -> 'a (*patterns and instantiations*) val mk_const_pat: theory -> string -> (ctyp -> 'a) -> 'a * cterm @@ -48,7 +49,10 @@ (*conversions*) val if_conv: (term -> bool) -> conv -> conv -> conv val if_true_conv: (term -> bool) -> conv -> conv + val if_exists_conv: (term -> bool) -> conv -> conv val binders_conv: (Proof.context -> conv) -> Proof.context -> conv + val under_quant_conv: (Proof.context * cterm list -> conv) -> + Proof.context -> conv val prop_conv: conv -> conv end @@ -110,6 +114,12 @@ fun dest_disj (@{const HOL.disj} $ t $ u) = (t, u) | dest_disj t = raise TERM ("not a disjunction", [t]) +fun under_quant f t = + (case t of + Const (@{const_name All}, _) $ Abs (_, _, u) => under_quant f u + | Const (@{const_name Ex}, _) $ Abs (_, _, u) => under_quant f u + | _ => f t) + (* patterns and instantiations *) @@ -164,9 +174,23 @@ fun if_true_conv pred cv = if_conv pred cv Conv.all_conv +fun if_exists_conv pred = if_true_conv (Term.exists_subterm pred) + fun binders_conv cv ctxt = Conv.binder_conv (binders_conv cv o snd) ctxt else_conv cv ctxt +fun under_quant_conv cv ctxt = + let + fun quant_conv inside ctxt cvs ct = + (case Thm.term_of ct of + Const (@{const_name All}, _) $ Abs _ => + Conv.binder_conv (under_conv cvs) ctxt + | Const (@{const_name Ex}, _) $ Abs _ => + Conv.binder_conv (under_conv cvs) ctxt + | _ => if inside then cv (ctxt, cvs) else Conv.all_conv) ct + and under_conv cvs (cv, ctxt) = quant_conv true ctxt (cv :: cvs) + in quant_conv false ctxt [] end + fun prop_conv cv ct = (case Thm.term_of ct of @{const Trueprop} $ _ => Conv.arg_conv cv ct diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/Tools/SMT/smtlib_interface.ML --- a/src/HOL/Tools/SMT/smtlib_interface.ML Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/Tools/SMT/smtlib_interface.ML Wed Dec 15 08:39:24 2010 +0100 @@ -23,77 +23,6 @@ val smtlibC = ["smtlib"] - -(* facts about uninterpreted constants *) - -infix 2 ?? -fun (ex ?? f) irules = irules |> exists (ex o Thm.prop_of o snd) irules ? f - - -(** pairs **) - -val pair_rules = [@{thm fst_conv}, @{thm snd_conv}, @{thm pair_collapse}] - -val pair_type = (fn Type (@{type_name Product_Type.prod}, _) => true | _ => false) -val exists_pair_type = Term.exists_type (Term.exists_subtype pair_type) - -val add_pair_rules = exists_pair_type ?? append (map (pair ~1) pair_rules) - - -(** function update **) - -val fun_upd_rules = [@{thm fun_upd_same}, @{thm fun_upd_apply}] - -val is_fun_upd = (fn Const (@{const_name fun_upd}, _) => true | _ => false) -val exists_fun_upd = Term.exists_subterm is_fun_upd - -val add_fun_upd_rules = exists_fun_upd ?? append (map (pair ~1) fun_upd_rules) - - -(** abs/min/max **) - -val exists_abs_min_max = Term.exists_subterm (fn - Const (@{const_name abs}, _) => true - | Const (@{const_name min}, _) => true - | Const (@{const_name max}, _) => true - | _ => false) - -val unfold_abs_conv = Conv.rewr_conv (mk_meta_eq @{thm abs_if}) -val unfold_min_conv = Conv.rewr_conv (mk_meta_eq @{thm min_def}) -val unfold_max_conv = Conv.rewr_conv (mk_meta_eq @{thm max_def}) - -fun expand_conv cv = N.eta_expand_conv (K cv) -fun expand2_conv cv = N.eta_expand_conv (N.eta_expand_conv (K cv)) - -fun unfold_def_conv ctxt ct = - (case Thm.term_of ct of - Const (@{const_name abs}, _) $ _ => unfold_abs_conv - | Const (@{const_name abs}, _) => expand_conv unfold_abs_conv ctxt - | Const (@{const_name min}, _) $ _ $ _ => unfold_min_conv - | Const (@{const_name min}, _) $ _ => expand_conv unfold_min_conv ctxt - | Const (@{const_name min}, _) => expand2_conv unfold_min_conv ctxt - | Const (@{const_name max}, _) $ _ $ _ => unfold_max_conv - | Const (@{const_name max}, _) $ _ => expand_conv unfold_max_conv ctxt - | Const (@{const_name max}, _) => expand2_conv unfold_max_conv ctxt - | _ => Conv.all_conv) ct - -fun unfold_abs_min_max_defs ctxt thm = - if exists_abs_min_max (Thm.prop_of thm) - then Conv.fconv_rule (Conv.top_conv unfold_def_conv ctxt) thm - else thm - - -(** include additional facts **) - -fun extra_norm has_datatypes irules ctxt = - irules - |> not has_datatypes ? add_pair_rules - |> add_fun_upd_rules - |> map (apsnd (unfold_abs_min_max_defs ctxt)) - |> rpair ctxt - - - (* builtins *) local @@ -131,7 +60,6 @@ end - (* serialization *) (** header **) @@ -215,12 +143,10 @@ |> Buffer.content - (* interface *) val interface = { class = smtlibC, - extra_norm = extra_norm, translate = { prefixes = { sort_prefix = "S", diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/Tools/SMT/z3_interface.ML --- a/src/HOL/Tools/SMT/z3_interface.ML Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/Tools/SMT/z3_interface.ML Wed Dec 15 08:39:24 2010 +0100 @@ -43,12 +43,13 @@ | is_int_div_mod @{const mod (int)} = true | is_int_div_mod _ = false - fun add_div_mod irules = - if exists (Term.exists_subterm is_int_div_mod o Thm.prop_of o snd) irules - then [(~1, @{thm div_by_z3div}), (~1, @{thm mod_by_z3mod})] @ irules - else irules + val have_int_div_mod = + exists (Term.exists_subterm is_int_div_mod o Thm.prop_of) - fun extra_norm' has_datatypes = extra_norm has_datatypes o add_div_mod + fun add_div_mod _ (thms, extra_thms) = + if have_int_div_mod thms orelse have_int_div_mod extra_thms then + (thms, @{thm div_by_z3div} :: @{thm mod_by_z3mod} :: extra_thms) + else (thms, extra_thms) val setup_builtins = B.add_builtin_fun' smtlib_z3C (@{const z3div}, "div") #> @@ -57,7 +58,6 @@ val interface = { class = smtlib_z3C, - extra_norm = extra_norm', translate = { prefixes = prefixes, is_fol = is_fol, @@ -65,7 +65,9 @@ has_datatypes = true, serialize = serialize}} -val setup = Context.theory_map setup_builtins +val setup = Context.theory_map ( + setup_builtins #> + SMT_Normalize.add_extra_norm (smtlib_z3C, add_div_mod)) end diff -r 4a9eec045f2a -r e0bd443c0fdd src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML --- a/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Wed Dec 15 08:39:24 2010 +0100 +++ b/src/HOL/Tools/Sledgehammer/sledgehammer_provers.ML Wed Dec 15 08:39:24 2010 +0100 @@ -528,7 +528,7 @@ #> Config.put SMT_Config.monomorph_limit smt_monomorph_limit val state = state |> Proof.map_context repair_context val thy = Proof.theory_of state - val facts = facts |> map (apsnd (Thm.transfer thy) o untranslated_fact) + val facts = facts |> map (apsnd (pair NONE o Thm.transfer thy) o untranslated_fact) val {outcome, used_facts, run_time_in_msecs} = smt_filter_loop params remote state subgoal facts val (chained_lemmas, other_lemmas) = split_used_facts (map fst used_facts)