src/HOL/Tools/SMT/z3_proof_tools.ML
changeset 58055 625bdd5c70b2
parent 58054 1d9edd486479
child 58056 fc6dd578d506
equal deleted inserted replaced
58054:1d9edd486479 58055:625bdd5c70b2
     1 (*  Title:      HOL/Tools/SMT/z3_proof_tools.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Helper functions required for Z3 proof reconstruction.
       
     5 *)
       
     6 
       
     signature Z3_PROOF_TOOLS =
     sig
       (** term manipulation **)

       (*meta-equation built from the two operands of the binary operator
         exposed by SMT_Utils.dest_cprop*)
       val as_meta_eq: cterm -> cterm

       (** discrimination nets of theorems **)

       (*net of items, each indexed by the proposition of its theorem*)
       val thm_net_of: ('a -> thm) -> 'a list -> 'a Net.net
       (*instances of all entries unifiable with the given cterm*)
       val net_instances: (int * thm) Net.net -> cterm -> (int * thm) list
       (*instance of the first entry matching the given cterm, if any*)
       val net_instance: thm Net.net -> cterm -> thm option

       (** proof combinators **)

       (*prove an implication from a function on its assumed premise*)
       val under_assumption: (thm -> thm) -> cterm -> thm
       (*prove a goal via the result of rewriting it with a conversion*)
       val with_conv: conv -> (cterm -> thm) -> cterm -> thm
       (*modus ponens; the premise comes first*)
       val discharge: thm -> thm -> thm
       (*generalize the named free variables*)
       val varify: string list -> thm -> thm
       (*rewrite with a list of meta-equations (identity for [])*)
       val unfold_eqs: Proof.context -> thm list -> conv
       (*instantiate a theorem so that a selected part matches a cterm*)
       val match_instantiate: (cterm -> cterm) -> cterm -> thm -> thm
       (*prove a goal by a tactic on its first subgoal*)
       val by_tac: Proof.context -> (int -> tactic) -> cterm -> thm
       (*turn the first premise of a theorem into a hypothetical definition*)
       val make_hyp_def: thm -> Proof.context -> thm * Proof.context
       (*prove a goal after abstracting subterms by fresh variables*)
       val by_abstraction: int -> bool * bool -> Proof.context -> thm list ->
         (Proof.context -> cterm -> thm) -> cterm -> thm

       (** a faster COMP **)

       type compose_data
       val precompose: (cterm -> cterm list) -> thm -> compose_data
       val precompose2: (cterm -> cterm * cterm) -> thm -> compose_data
       val compose: compose_data -> thm -> thm

       (** unfolding of 'distinct' **)

       val unfold_distinct_conv: conv

       (** simpset **)

       val add_simproc: Simplifier.simproc -> Context.generic -> Context.generic
       val make_simpset: Proof.context -> thm list -> simpset
     end
       
    42 
       
    43 structure Z3_Proof_Tools: Z3_PROOF_TOOLS =
       
    44 struct
       
    45 
       
    46 
       
    47 
       
    48 (* modifying terms *)
       
    49 
       
    (*Meta-equation "lhs == rhs" built from the two operands of the binary
      operator that SMT_Utils.dest_cprop exposes in ct.*)
    fun as_meta_eq ct =
      let val (lhs, rhs) = Thm.dest_binop (SMT_Utils.dest_cprop ct)
      in SMT_Utils.mk_cequals lhs rhs end
       
    52 
       
    53 
       
    54 
       
    55 (* theorem nets *)
       
    56 
       
    (*Discrimination net of the items xthms, each indexed by the proposition
      of its theorem f xthm.  The equality test K false never reports an
      entry as a duplicate.*)
    fun thm_net_of f xthms =
      let
        fun add_entry xthm net =
          Net.insert_term (K false) (Thm.prop_of (f xthm), xthm) net
      in fold add_entry xthms Net.empty end
       
    60 
       
    (*Instantiate thm so that its proposition becomes ct, provided first-order
      matching of the proposition against ct succeeds; NONE otherwise.*)
    fun maybe_instantiate ct thm =
      (case try Thm.first_order_match (Thm.cprop_of thm, ct) of
        SOME inst => SOME (Thm.instantiate inst thm)
      | NONE => NONE)
       
    64 
       
    local
      (*Look up ct in net (by matching if match, by unification otherwise).
        The candidates are then instantiated towards ct through the wrapper f:

        - first by first-order matching;
        - only if that yields nothing, by COMP against Thm.trivial ct.*)
      fun instances_from_net match f net ct =
        let
          val lookup = if match then Net.match_term else Net.unify_term
          val candidates = lookup net (Thm.term_of ct)
          fun collect inst = map_filter (f inst) candidates
          val direct = collect (maybe_instantiate ct)
        in
          if null direct then
            let val triv = Thm.trivial ct
            in collect (try (fn rule => rule COMP triv)) end
          else direct
        end
    in

    (*Instances of all net entries unifiable with ct, keeping their indices.*)
    fun net_instances net =
      instances_from_net false
        (fn inst => fn (i, thm) => Option.map (pair i) (inst thm)) net

    (*Instance of the first net entry matching ct, if there is one.*)
    fun net_instance net ct =
      (case instances_from_net true I net ct of
        [] => NONE
      | thm :: _ => SOME thm)

    end
       
    84 
       
    85 
       
    86 
       
    87 (* proof combinators *)
       
    88 
       
    (*Prove "P ==> Q", where P is ct turned into a proposition by
      SMT_Utils.mk_cprop and Q is what f derives from the assumption P.*)
    fun under_assumption f ct =
      let
        val prem = SMT_Utils.mk_cprop ct
        val concl = f (Thm.assume prem)
      in Thm.implies_intr prem concl end
       
    92 
       
    (*Prove ct as follows:

      - rewrite it with conv to some ct';
      - prove ct' with prove;
      - transport the result back along the reversed equation.*)
    fun with_conv conv prove ct =
      let
        val back = Thm.symmetric (conv ct)
        val goal = Thm.lhs_of back
      in Thm.equal_elim back (prove goal) end
       
    96 
       
    (*Modus ponens: from the premise p and the implication "p ==> q" derive q.*)
    fun discharge premise implication = Thm.implies_elim implication premise
       
    98 
       
    (*Generalize the free variables named in vars (via Drule.generalize; the
      empty first component presumably means: no type variables).*)
    fun varify vars thm = Drule.generalize ([], vars) thm
       
   100 
       
   (*Rewrite with the meta-equations eqs in a top-down sweep; without any
     equations this is the identity conversion.*)
   fun unfold_eqs ctxt eqs =
     (case eqs of
       [] => Conv.all_conv
     | _ => Conv.top_sweep_conv (K (Conv.rewrs_conv eqs)) ctxt)
       
   104 
       
   (*Instantiate thm so that the part of its proposition selected by f
     matches ct.*)
   fun match_instantiate f ct thm =
     let val pattern = f (Thm.cprop_of thm)
     in Thm.instantiate (Thm.match (pattern, ct)) thm end
       
   107 
       
   (*Prove the goal ct by applying tac to its first subgoal, and normalize the
     result with Goal.norm_result.*)
   fun by_tac ctxt tac ct =
     let val thm = Goal.prove_internal ctxt [] ct (K (tac 1))
     in Goal.norm_result ctxt thm end
       
   109 
       
   (*
      |- c x == t x ==> P (c x)
     ---------------------------
         c == t |- P (c x)

     The first premise of thm, "c x1 ... xn == t", is turned into the
     equation "c == %x1 ... xn. t".  That equation is added to ctxt as an
     assumption and then used to discharge the premise.
   *)
   fun make_hyp_def thm ctxt =
     let
       val (clhs, crhs) = Thm.dest_binop (Thm.cprem_of thm 1)
       val (chead, cargs) = Drule.strip_comb clhs
       val cdef = SMT_Utils.mk_cequals chead (fold_rev Thm.lambda cargs crhs)
       (*from "f == g" derive "f x == g x", beta-reducing the right-hand side
         (Thm.beta_conversion false)*)
       fun apply_arg carg eqn =
         Conv.fconv_rule (Conv.arg_conv (Thm.beta_conversion false))
           (Thm.combination eqn (Thm.reflexive carg))
       val (hyp, ctxt') = yield_singleton Assumption.add_assumes cdef ctxt
     in (Thm.implies_elim thm (fold apply_arg cargs hyp), ctxt') end
       
   127 
       
   128 
       
   129 
       
   130 (* abstraction *)
       
   131 
       
   local

   (*The abstraction state threaded through the traversal is a triple:

     - the context in which the abstraction variables are fixed;
     - a table mapping each abstracted term (lambda-abstracted over the bound
       variables it uses) to its abstraction variable and its original term;
     - a flag telling whether re-instantiation needs beta normalization.*)
   fun abs_context ctxt = (ctxt, Termtab.empty, false)

   fun context_of (ctxt, _, _) = ctxt

   (*replace the free variable cv by ct in a theorem*)
   fun replace (_, (cv, ct)) = Thm.forall_elim ct o Thm.forall_intr cv

   (*undo all recorded abstractions in a theorem about the abstracted term*)
   fun abs_instantiate (_, tab, beta_norm) =
     fold replace (Termtab.dest tab) #>
     beta_norm ? Conv.fconv_rule (Thm.beta_conversion true)

   (*abstract t over those variables of cvs that occur free in t; also return
     the variables actually used*)
   fun lambda_abstract cvs t =
     let
       val frees = map Free (Term.add_frees t [])
       val cvs' = filter (fn cv => member (op aconv) frees (Thm.term_of cv)) cvs
       val vs = map (Term.dest_Free o Thm.term_of) cvs'
     in (fold_rev absfree vs t, cvs') end

   (*replace ct by a variable applied to the bound variables that ct uses;
     the variable is reused for a term that was abstracted before*)
   fun fresh_abstraction (_, cvs) ct (cx as (ctxt, tab, beta_norm)) =
     let val (t, cvs') = lambda_abstract cvs (Thm.term_of ct)
     in
       (case Termtab.lookup tab t of
         SOME (cv, _) => (Drule.list_comb (cv, cvs'), cx)
       | NONE =>
           let
             val (n, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
             val cv = SMT_Utils.certify ctxt'
               (Free (n, map SMT_Utils.typ_of cvs' ---> SMT_Utils.typ_of ct))
             val cu = Drule.list_comb (cv, cvs')
             val e = (t, (cv, fold_rev Thm.lambda cvs' ct))
             (*an abstraction applied to bound variables leaves beta redexes
               behind once it is re-instantiated*)
             val beta_norm' = beta_norm orelse not (null cvs')
           in (cu, (ctxt', Termtab.update e tab, beta_norm')) end)
     end

   (*traverse a combination: f on the function part, g on the argument*)
   fun abs_comb f g dcvs ct =
     let val (cf, cu) = Thm.dest_comb ct
     in f dcvs cf ##>> g dcvs cu #>> uncurry Thm.apply end

   (*keep the function part unchanged, traverse only the argument with f*)
   fun abs_arg f = abs_comb (K pair) f

   (*traverse every argument of an application with f, keeping its head*)
   fun abs_args f dcvs ct =
     (case Thm.term_of ct of
       _ $ _ => abs_comb (abs_args f) f dcvs ct
     | _ => pair ct)

   (*traverse the elements of an explicit Nil/Cons list with f; any other
     term at list position is handed to g*)
   fun abs_list f g dcvs ct =
     (case Thm.term_of ct of
       Const (@{const_name Nil}, _) => pair ct
     | Const (@{const_name Cons}, _) $ _ $ _ =>
         abs_comb (abs_arg f) (abs_list f g) dcvs ct
     | _ => g dcvs ct)

   (*traverse the body of a lambda abstraction, recording the free variable
     that Thm.dest_abs introduces for its bound variable*)
   fun abs_abs f (depth, cvs) ct =
     let val (cv, cu) = Thm.dest_abs NONE ct
     in f (depth, cv :: cvs) cu #>> Thm.lambda cv end

   val is_atomic =
     (fn Free _ => true | Var _ => true | Bound _ => true | _ => false)

   (*Replace subterms of a goal by fresh variables.

     Always kept:
     - Trueprop, ==>, True, False, the propositional connectives and
       equality;
     - atomic terms and numerals.

     Kept only conditionally:
     - distinct, If, All and Ex, only if ext_logic;
     - builtin theory terms (Z3_Interface), only if with_theories;
     - applications with a constant or free head, up to depth levels.

     Everything else is abstracted.*)
   fun abstract depth (ext_logic, with_theories) =
     let
       fun abstr1 cvs ct = abs_arg abstr cvs ct
       and abstr2 cvs ct = abs_comb abstr1 abstr cvs ct
       and abstr3 cvs ct = abs_comb abstr2 abstr cvs ct
       and abstr_abs cvs ct = abs_arg (abs_abs abstr) cvs ct

       and abstr (dcvs as (d, cvs)) ct =
         (case Thm.term_of ct of
           @{const Trueprop} $ _ => abstr1 dcvs ct
         | @{const Pure.imp} $ _ $ _ => abstr2 dcvs ct
         | @{const True} => pair ct
         | @{const False} => pair ct
         | @{const Not} $ _ => abstr1 dcvs ct
         | @{const HOL.conj} $ _ $ _ => abstr2 dcvs ct
         | @{const HOL.disj} $ _ $ _ => abstr2 dcvs ct
         | @{const HOL.implies} $ _ $ _ => abstr2 dcvs ct
         | Const (@{const_name HOL.eq}, _) $ _ $ _ => abstr2 dcvs ct
         | Const (@{const_name distinct}, _) $ _ =>
             if ext_logic then abs_arg (abs_list abstr fresh_abstraction) dcvs ct
             else fresh_abstraction dcvs ct
         | Const (@{const_name If}, _) $ _ $ _ $ _ =>
             if ext_logic then abstr3 dcvs ct else fresh_abstraction dcvs ct
         | Const (@{const_name All}, _) $ _ =>
             if ext_logic then abstr_abs dcvs ct else fresh_abstraction dcvs ct
         | Const (@{const_name Ex}, _) $ _ =>
             if ext_logic then abstr_abs dcvs ct else fresh_abstraction dcvs ct
         | t => (fn cx =>
             if is_atomic t orelse can HOLogic.dest_number t then (ct, cx)
             else if with_theories andalso
               Z3_Interface.is_builtin_theory_term (context_of cx) t
             then abs_args abstr dcvs ct cx
             else if d = 0 then fresh_abstraction dcvs ct cx
             else
               (case Term.strip_comb t of
                 (Const _, _) => abs_args abstr (d-1, cvs) ct cx
               | (Free _, _) => abs_args abstr (d-1, cvs) ct cx
               | _ => fresh_abstraction dcvs ct cx)))
     in abstr (depth, []) end

   val cimp = Thm.cterm_of @{theory} @{const Pure.imp}

   (*try f at depth and, on failure, at successively smaller depths; a
     failure at depth 0 propagates*)
   fun deepen depth f x =
     if depth = 0 then f depth x
     else (case try (f depth) x of SOME y => y | NONE => deepen (depth - 1) f x)

   (*prove the goal with the propositions of thms as premises (using f via
     deepen), then discharge those premises with thms*)
   fun with_prems depth thms f ct =
     fold_rev (Thm.mk_binop cimp o Thm.cprop_of) thms ct
     |> deepen depth f
     |> fold (fn prem => fn th => Thm.implies_elim th prem) thms

   in

   (*Prove ct, using thms as premises, as follows:

     - abstract the goal (see abstract) and hand it to prove;
     - re-instantiate the abstractions in the result.

     The abstraction depth is lowered step by step until prove succeeds.*)
   fun by_abstraction depth mode ctxt thms prove =
     with_prems depth thms (fn d => fn ct =>
       let val (cu, cx) = abstract d mode ct (abs_context ctxt)
       in abs_instantiate cx (prove (context_of cx) cu) end)

   end
       
   251 
       
   252 
       
   253 
       
   254 (* a faster COMP *)
       
   255 
       
   (*Prepared rule for compose.  Its components are:

     - the variables to be instantiated, as extracted by the function from
       the rule's first premise;
     - the extraction function itself;
     - the rule.*)
   type compose_data = cterm list * (cterm -> cterm list) * thm

   fun list2 (x, y) = [x, y]

   (*prepare rule: f selects, from the first premise, the terms to be
     instantiated*)
   fun precompose f rule =
     let val cvs = f (Thm.cprem_of rule 1)
     in (cvs, f, rule) end

   (*variant of precompose for selection functions returning a pair*)
   fun precompose2 f rule = precompose (fn ct => list2 (f ct)) rule

   (*Instantiate the prepared variables with the corresponding parts of thm's
     proposition (by plain instantiation, not unification).  Then discharge
     the rule's first premise with thm.*)
   fun compose (cvs, f, rule) thm =
     let
       val insts = cvs ~~ f (Thm.cprop_of thm)
       val rule' = Thm.instantiate ([], insts) rule
     in discharge thm rule' end
       
   265 
       
   266 
       
   267 
       
   268 (* unfolding of 'distinct' *)
       
   269 
       
   local
     (*Rewrite rules for "x ~: set xs" on explicit lists.  Conv.rewrs_conv
       tries them in order, so the rules for an element equal to x come
       before the general Cons rule.*)
     val set_final_rules = [
       @{lemma "x ~: set [] == ~False" by simp},
       @{lemma "x ~: set [x] == False" by simp},
       @{lemma "x ~: set [y] == x ~= y" by simp},
       @{lemma "x ~: set (x # ys) == False" by simp}]
     val set_cons_rule = @{lemma "x ~: set (y # ys) == x ~= y & x ~: set ys" by simp}

     (*unfold "x ~: set [y1, ..., yn]" element by element*)
     fun set_conv ct =
       let
         val unfold_cons =
           Conv.rewr_conv set_cons_rule then_conv Conv.arg_conv set_conv
       in (Conv.rewrs_conv set_final_rules else_conv unfold_cons) ct end

     val distinct_final_rules = [
       @{lemma "distinct [] == ~False" by (simp add: distinct_def)},
       @{lemma "distinct [x] == ~False" by (simp add: distinct_def)}]
     val distinct_cons_rule = @{lemma "distinct (x # xs) == x ~: set xs & distinct xs"
       by (simp add: distinct_def)}

     (*apply cv1 to the first and cv2 to the second operand of a binary
       operator*)
     fun operands_conv cv1 cv2 = Conv.combination_conv (Conv.arg_conv cv1) cv2
   in

   (*Unfold "distinct [x1, ..., xn]" on an explicit list into a conjunction
     of "~: set" facts, each of which is unfolded further by set_conv.*)
   fun unfold_distinct_conv ct =
     let
       val unfold_cons =
         Conv.rewr_conv distinct_cons_rule then_conv
         operands_conv set_conv unfold_distinct_conv
     in (Conv.rewrs_conv distinct_final_rules else_conv unfold_cons) ct end

   end
       
   292 
       
   293 
       
   294 
       
   295 (* simpset *)
       
   296 
       
   297 local
       
   298   val antisym_le1 = mk_meta_eq @{thm order_class.antisym_conv}
       
   299   val antisym_le2 = mk_meta_eq @{thm linorder_class.antisym_conv2}
       
   300   val antisym_less1 = mk_meta_eq @{thm linorder_class.antisym_conv1}
       
   301   val antisym_less2 = mk_meta_eq @{thm linorder_class.antisym_conv3}
       
   302 
       
   303   fun eq_prop t thm = HOLogic.mk_Trueprop t aconv Thm.prop_of thm
       
   304   fun dest_binop ((c as Const _) $ t $ u) = (c, t, u)
       
   305     | dest_binop t = raise TERM ("dest_binop", [t])
       
   306 
       
   307   fun prove_antisym_le ctxt t =
       
   308     let
       
   309       val (le, r, s) = dest_binop t
       
   310       val less = Const (@{const_name less}, Term.fastype_of le)
       
   311       val prems = Simplifier.prems_of ctxt
       
   312     in
       
   313       (case find_first (eq_prop (le $ s $ r)) prems of
       
   314         NONE =>
       
   315           find_first (eq_prop (HOLogic.mk_not (less $ r $ s))) prems
       
   316           |> Option.map (fn thm => thm RS antisym_less1)
       
   317       | SOME thm => SOME (thm RS antisym_le1))
       
   318     end
       
   319     handle THM _ => NONE
       
   320 
       
   321   fun prove_antisym_less ctxt t =
       
   322     let
       
   323       val (less, r, s) = dest_binop (HOLogic.dest_not t)
       
   324       val le = Const (@{const_name less_eq}, Term.fastype_of less)
       
   325       val prems = Simplifier.prems_of ctxt
       
   326     in
       
   327       (case find_first (eq_prop (le $ r $ s)) prems of
       
   328         NONE =>
       
   329           find_first (eq_prop (HOLogic.mk_not (less $ s $ r))) prems
       
   330           |> Option.map (fn thm => thm RS antisym_less2)
       
   331       | SOME thm => SOME (thm RS antisym_le2))
       
   332   end
       
   333   handle THM _ => NONE
       
   334 
       
   335   val basic_simpset =
       
   336     simpset_of (put_simpset HOL_ss @{context}
       
   337       addsimps @{thms field_simps}
       
   338       addsimps [@{thm times_divide_eq_right}, @{thm times_divide_eq_left}]
       
   339       addsimps @{thms arith_special} addsimps @{thms arith_simps}
       
   340       addsimps @{thms rel_simps}
       
   341       addsimps @{thms array_rules}
       
   342       addsimps @{thms term_true_def} addsimps @{thms term_false_def}
       
   343       addsimps @{thms z3div_def} addsimps @{thms z3mod_def}
       
   344       addsimprocs [@{simproc binary_int_div}, @{simproc binary_int_mod}]
       
   345       addsimprocs [
       
   346         Simplifier.simproc_global @{theory} "fast_int_arith" [
       
   347           "(m::int) < n", "(m::int) <= n", "(m::int) = n"] Lin_Arith.simproc,
       
   348         Simplifier.simproc_global @{theory} "antisym_le" ["(x::'a::order) <= y"]
       
   349           prove_antisym_le,
       
   350         Simplifier.simproc_global @{theory} "antisym_less" ["~ (x::'a::linorder) < y"]
       
   351           prove_antisym_less])
       
   352 
       
   353   structure Simpset = Generic_Data
       
   354   (
       
   355     type T = simpset
       
   356     val empty = basic_simpset
       
   357     val extend = I
       
   358     val merge = Simplifier.merge_ss
       
   359   )
       
   360 in
       
   361 
       
   362 fun add_simproc simproc context =
       
   363   Simpset.map (simpset_map (Context.proof_of context)
       
   364     (fn ctxt => ctxt addsimprocs [simproc])) context
       
   365 
       
   366 fun make_simpset ctxt rules =
       
   367   simpset_of (put_simpset (Simpset.get (Context.Proof ctxt)) ctxt addsimps rules)
       
   368 
       
   369 end
       
   370 
       
   371 end