src/HOL/Tools/SMT/smt_replay.ML
changeset 69204 d5ab1636660b
child 69593 3dda49e08b9d
equal deleted inserted replaced
69203:a5c0d61ce5db 69204:d5ab1636660b
       
     1 (*  Title:      HOL/Tools/SMT/smt_replay.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3     Author:     Jasmin Blanchette, TU Muenchen
       
     4     Author:     Mathias Fleury, MPII
       
     5 
       
     6 Shared library for parsing and replay.
       
     7 *)
       
     8 
       
(*Interface of the shared SMT proof-replay library.*)
signature SMT_REPLAY =
sig
  (*theorem nets*)
  (*index items in a discrimination net by the proposition of the theorem that the
    first argument extracts from each item*)
  val thm_net_of: ('a -> thm) -> 'a list -> 'a Net.net
  (*indexed theorems from the net that can be instantiated to (or composed with)
    the given cterm, keeping their indices*)
  val net_instances: (int * thm) Net.net -> cterm -> (int * thm) list

  (*proof combinators*)
  (*under_assumption f ct: assume ct (turned into a proposition), derive a theorem
    with f, and discharge the assumption again as a premise*)
  val under_assumption: (thm -> thm) -> cterm -> thm
  (*discharge p pq: eliminate the premise of pq using p*)
  val discharge: thm -> thm -> thm

  (*a faster COMP*)
  (*cached data for composing a rule: the cterms extracted from the rule's first
    premise, the extraction function, and the rule itself*)
  type compose_data = cterm list * (cterm -> cterm list) * thm
  val precompose: (cterm -> cterm list) -> thm -> compose_data
  (*like precompose, for an extraction function that yields exactly two cterms*)
  val precompose2: (cterm -> cterm * cterm) -> thm -> compose_data
  (*instantiate the cached rule to fit the given theorem and discharge its premise*)
  val compose: compose_data -> thm -> thm

  (*simpset*)
  (*add a simproc to the replay simpset stored in the generic context*)
  val add_simproc: Simplifier.simproc -> Context.generic -> Context.generic
  (*the stored replay simpset, extended with the given rules*)
  val make_simpset: Proof.context -> thm list -> simpset

  (*assertion*)
  (*add_asserted tab_update tab_empty p_extract cond outer_ctxt rewrite_rules assms steps ctxt:
    replay the proof steps accepted by cond as assertions, matching them against the
    given indexed assumptions; returns the matched assumptions, the assumptions that
    matched only up to instantiation, the extended context and the filled table*)
  val add_asserted:  ('a * ('b * thm) -> 'c -> 'c) ->
    'c -> ('d -> 'a * 'e * term * 'b) -> ('e -> bool) -> Proof.context -> thm list ->
    (int * thm) list -> 'd list -> Proof.context ->
    ((int * ('a * thm)) list * thm list) * (Proof.context * 'c)
  
  (*statistics*)
  (*pretty_statistics solver total stats: a report with the total time (in ms) and,
    per entry of stats, the number, middle, maximum and sum of its recorded times*)
  val pretty_statistics: string -> int -> int list Symtab.table -> Pretty.T
  (*intermediate_statistics ctxt start total current: progress message reporting
    current of total reconstructed steps and the time elapsed since start*)
  val intermediate_statistics: Proof.context -> Timing.start -> int -> int -> unit

  (*theorem transformation*)
  (*replace the outermost Pure.all-bound parameters by fresh schematic variables*)
  val varify: Proof.context -> thm -> thm
  (*names and types of the outermost Pure.all-bound parameters of a term*)
  val params_of: term -> (string * typ) list
end;
       
    43 
       
    44 structure SMT_Replay : SMT_REPLAY =
       
    45 struct
       
    46 
       
    47 (* theorem nets *)
       
    48 
       
    49 fun thm_net_of f xthms =
       
    50   let fun insert xthm = Net.insert_term (K false) (Thm.prop_of (f xthm), xthm)
       
    51   in fold insert xthms Net.empty end
       
    52 
       
    53 fun maybe_instantiate ct thm =
       
    54   try Thm.first_order_match (Thm.cprop_of thm, ct)
       
    55   |> Option.map (fn inst => Thm.instantiate inst thm)
       
    56 
       
    57 local
       
    58   fun instances_from_net match f net ct =
       
    59     let
       
    60       val lookup = if match then Net.match_term else Net.unify_term
       
    61       val xthms = lookup net (Thm.term_of ct)
       
    62       fun select ct = map_filter (f (maybe_instantiate ct)) xthms
       
    63       fun select' ct =
       
    64         let val thm = Thm.trivial ct
       
    65         in map_filter (f (try (fn rule => rule COMP thm))) xthms end
       
    66     in (case select ct of [] => select' ct | xthms' => xthms') end
       
    67 in
       
    68 
       
    69 fun net_instances net =
       
    70   instances_from_net false (fn f => fn (i, thm) => Option.map (pair i) (f thm))
       
    71     net
       
    72 
       
    73 end
       
    74 
       
    75 
       
    76 (* proof combinators *)
       
    77 
       
    78 fun under_assumption f ct =
       
    79   let val ct' = SMT_Util.mk_cprop ct in Thm.implies_intr ct' (f (Thm.assume ct')) end
       
    80 
       
    81 fun discharge p pq = Thm.implies_elim pq p
       
    82 
       
    83 
       
    84 (* a faster COMP *)
       
    85 
       
    86 type compose_data = cterm list * (cterm -> cterm list) * thm
       
    87 
       
    88 fun list2 (x, y) = [x, y]
       
    89 
       
    90 fun precompose f rule : compose_data = (f (Thm.cprem_of rule 1), f, rule)
       
    91 fun precompose2 f rule : compose_data = precompose (list2 o f) rule
       
    92 
       
    93 fun compose (cvs, f, rule) thm =
       
    94   discharge thm
       
    95     (Thm.instantiate ([], map (dest_Var o Thm.term_of) cvs ~~ f (Thm.cprop_of thm)) rule)
       
    96 
       
    97 
       
    98 (* simpset *)
       
    99 
       
   100 local
       
   101   val antisym_le1 = mk_meta_eq @{thm order_class.antisym_conv}
       
   102   val antisym_le2 = mk_meta_eq @{thm linorder_class.antisym_conv2}
       
   103   val antisym_less1 = mk_meta_eq @{thm linorder_class.antisym_conv1}
       
   104   val antisym_less2 = mk_meta_eq @{thm linorder_class.antisym_conv3}
       
   105 
       
   106   fun eq_prop t thm = HOLogic.mk_Trueprop t aconv Thm.prop_of thm
       
   107   fun dest_binop ((c as Const _) $ t $ u) = (c, t, u)
       
   108     | dest_binop t = raise TERM ("dest_binop", [t])
       
   109 
       
   110   fun prove_antisym_le ctxt ct =
       
   111     let
       
   112       val (le, r, s) = dest_binop (Thm.term_of ct)
       
   113       val less = Const (@{const_name less}, Term.fastype_of le)
       
   114       val prems = Simplifier.prems_of ctxt
       
   115     in
       
   116       (case find_first (eq_prop (le $ s $ r)) prems of
       
   117         NONE =>
       
   118           find_first (eq_prop (HOLogic.mk_not (less $ r $ s))) prems
       
   119           |> Option.map (fn thm => thm RS antisym_less1)
       
   120       | SOME thm => SOME (thm RS antisym_le1))
       
   121     end
       
   122     handle THM _ => NONE
       
   123 
       
   124   fun prove_antisym_less ctxt ct =
       
   125     let
       
   126       val (less, r, s) = dest_binop (HOLogic.dest_not (Thm.term_of ct))
       
   127       val le = Const (@{const_name less_eq}, Term.fastype_of less)
       
   128       val prems = Simplifier.prems_of ctxt
       
   129     in
       
   130       (case find_first (eq_prop (le $ r $ s)) prems of
       
   131         NONE =>
       
   132           find_first (eq_prop (HOLogic.mk_not (less $ s $ r))) prems
       
   133           |> Option.map (fn thm => thm RS antisym_less2)
       
   134       | SOME thm => SOME (thm RS antisym_le2))
       
   135   end
       
   136   handle THM _ => NONE
       
   137 
       
   138   val basic_simpset =
       
   139     simpset_of (put_simpset HOL_ss @{context}
       
   140       addsimps @{thms field_simps times_divide_eq_right times_divide_eq_left arith_special
       
   141         arith_simps rel_simps array_rules z3div_def z3mod_def NO_MATCH_def}
       
   142       addsimprocs [@{simproc numeral_divmod},
       
   143         Simplifier.make_simproc @{context} "fast_int_arith"
       
   144          {lhss = [@{term "(m::int) < n"}, @{term "(m::int) \<le> n"}, @{term "(m::int) = n"}],
       
   145           proc = K Lin_Arith.simproc},
       
   146         Simplifier.make_simproc @{context} "antisym_le"
       
   147          {lhss = [@{term "(x::'a::order) \<le> y"}],
       
   148           proc = K prove_antisym_le},
       
   149         Simplifier.make_simproc @{context} "antisym_less"
       
   150          {lhss = [@{term "\<not> (x::'a::linorder) < y"}],
       
   151           proc = K prove_antisym_less}])
       
   152 
       
   153   structure Simpset = Generic_Data
       
   154   (
       
   155     type T = simpset
       
   156     val empty = basic_simpset
       
   157     val extend = I
       
   158     val merge = Simplifier.merge_ss
       
   159   )
       
   160 in
       
   161 
       
   162 fun add_simproc simproc context =
       
   163   Simpset.map (simpset_map (Context.proof_of context)
       
   164     (fn ctxt => ctxt addsimprocs [simproc])) context
       
   165 
       
   166 fun make_simpset ctxt rules =
       
   167   simpset_of (put_simpset (Simpset.get (Context.Proof ctxt)) ctxt addsimps rules)
       
   168 
       
   169 end
       
   170 
       
   171 local
       
   172   val remove_trigger = mk_meta_eq @{thm trigger_def}
       
   173   val remove_fun_app = mk_meta_eq @{thm fun_app_def}
       
   174 
       
   175   fun rewrite_conv _ [] = Conv.all_conv
       
   176     | rewrite_conv ctxt eqs = Simplifier.full_rewrite (empty_simpset ctxt addsimps eqs)
       
   177 
       
   178   val rewrite_true_rule = @{lemma "True \<equiv> \<not> False" by simp}
       
   179   val prep_rules = [@{thm Let_def}, remove_trigger, remove_fun_app, rewrite_true_rule]
       
   180 
       
   181   fun rewrite _ [] = I
       
   182     | rewrite ctxt eqs = Conv.fconv_rule (rewrite_conv ctxt eqs)
       
   183 
       
   184   fun lookup_assm assms_net ct =
       
   185     net_instances assms_net ct
       
   186     |> map (fn ithm as (_, thm) => (ithm, Thm.cprop_of thm aconvc ct))
       
   187 in
       
   188 
       
   189 fun add_asserted tab_update tab_empty p_extract cond outer_ctxt rewrite_rules assms steps ctxt =
       
   190   let
       
   191     val eqs = map (rewrite ctxt [rewrite_true_rule]) rewrite_rules
       
   192     val eqs' = union Thm.eq_thm eqs prep_rules
       
   193 
       
   194     val assms_net =
       
   195       assms
       
   196       |> map (apsnd (rewrite ctxt eqs'))
       
   197       |> map (apsnd (Conv.fconv_rule Thm.eta_conversion))
       
   198       |> thm_net_of snd
       
   199 
       
   200     fun revert_conv ctxt = rewrite_conv ctxt eqs' then_conv Thm.eta_conversion
       
   201 
       
   202     fun assume thm ctxt =
       
   203       let
       
   204         val ct = Thm.cprem_of thm 1
       
   205         val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
       
   206       in (thm' RS thm, ctxt') end
       
   207 
       
   208     fun add1 id fixes thm1 ((i, th), exact) ((iidths, thms), (ctxt, ptab)) =
       
   209       let
       
   210         val (thm, ctxt') = if exact then (Thm.implies_elim thm1 th, ctxt) else assume thm1 ctxt
       
   211         val thms' = if exact then thms else th :: thms
       
   212       in (((i, (id, th)) :: iidths, thms'), (ctxt', tab_update (id, (fixes, thm)) ptab)) end
       
   213 
       
   214     fun add step
       
   215         (cx as ((iidths, thms), (ctxt, ptab))) =
       
   216       let val (id, rule, concl, fixes) = p_extract step in
       
   217         if (*Z3_Proof.is_assumption rule andalso rule <> Z3_Proof.Hypothesis*) cond rule then
       
   218           let
       
   219             val ct = Thm.cterm_of ctxt concl
       
   220             val thm1 = Thm.trivial ct |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
       
   221             val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
       
   222           in
       
   223             (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
       
   224               [] =>
       
   225                 let val (thm, ctxt') = assume thm1 ctxt
       
   226                 in ((iidths, thms), (ctxt', tab_update (id, (fixes, thm)) ptab)) end
       
   227             | ithms => fold (add1 id fixes thm1) ithms cx)
       
   228           end
       
   229         else
       
   230           cx
       
   231       end
       
   232   in fold add steps (([], []), (ctxt, tab_empty)) end
       
   233 
       
   234 end
       
   235 
       
   236 fun params_of t = Term.strip_qnt_vars @{const_name Pure.all} t
       
   237 
       
   238 fun varify ctxt thm =
       
   239   let
       
   240     val maxidx = Thm.maxidx_of thm + 1
       
   241     val vs = params_of (Thm.prop_of thm)
       
   242     val vars = map_index (fn (i, (n, T)) => Var ((n, i + maxidx), T)) vs
       
   243   in Drule.forall_elim_list (map (Thm.cterm_of ctxt) vars) thm end
       
   244 
       
   245 fun intermediate_statistics ctxt start total =
       
   246   SMT_Config.statistics_msg ctxt (fn current =>
       
   247     "Reconstructed " ^ string_of_int current ^ " of " ^ string_of_int total ^ " steps in " ^
       
   248     string_of_int (Time.toMilliseconds (#elapsed (Timing.result start))) ^ " ms")
       
   249 
       
   250 fun pretty_statistics solver total stats =
       
   251   let
       
   252     fun mean_of is =
       
   253       let
       
   254         val len = length is
       
   255         val mid = len div 2
       
   256       in if len mod 2 = 0 then (nth is (mid - 1) + nth is mid) div 2 else nth is mid end
       
   257     fun pretty_item name p = Pretty.item (Pretty.separate ":" [Pretty.str name, p])
       
   258     fun pretty (name, milliseconds) = pretty_item name (Pretty.block (Pretty.separate "," [
       
   259       Pretty.str (string_of_int (length milliseconds) ^ " occurrences") ,
       
   260       Pretty.str (string_of_int (mean_of milliseconds) ^ " ms mean time"),
       
   261       Pretty.str (string_of_int (fold Integer.max milliseconds 0) ^ " ms maximum time"),
       
   262       Pretty.str (string_of_int (fold Integer.add milliseconds 0) ^ " ms total time")]))
       
   263   in
       
   264     Pretty.big_list (solver ^ " proof reconstruction statistics:") (
       
   265       pretty_item "total time" (Pretty.str (string_of_int total ^ " ms")) ::
       
   266       map pretty (Symtab.dest stats))
       
   267   end
       
   268 
       
   269 end;