src/HOL/Tools/SMT2/z3_new_replay.ML
changeset 58061 3d060f43accb
parent 58060 835b5443b978
child 58062 f4d8987656b9
equal deleted inserted replaced
58060:835b5443b978 58061:3d060f43accb
     1 (*  Title:      HOL/Tools/SMT2/z3_new_replay.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3     Author:     Jasmin Blanchette, TU Muenchen
       
     4 
       
     5 Z3 proof parsing and replay.
       
     6 *)
       
     7 
       
     (* Interface of the Z3 proof replay module: one entry point that only parses
        a proof and one that replays it into a theorem. *)
     8 signature Z3_NEW_REPLAY =
       
     9 sig
       
         (* arguments, in order: outer context, replay data, extra facts (xfacts),
            premises, conclusion, solver output lines *)
    10   val parse_proof: Proof.context -> SMT2_Translate.replay_data ->
       
    11     ((string * ATP_Problem_Generate.stature) * thm) list -> term list -> term -> string list ->
       
    12     SMT2_Solver.parsed_proof
       
         (* arguments, in order: outer context, replay data, solver output lines *)
    13   val replay: Proof.context -> SMT2_Translate.replay_data -> string list -> thm
       
    14 end;
       
    15 
       
    16 structure Z3_New_Replay: Z3_NEW_REPLAY =
       
    17 struct
       
    18 
       
    (* Variables bound by the leading Pure.all quantifiers of t.  The callers in
       this file destructure each entry as a (name, type) pair. *)
    19 fun params_of t = Term.strip_qnt_vars @{const_name Pure.all} t
       
    20 
       
    (* Instantiate the leading Pure.all-bound parameters of thm with schematic
       variables that keep each parameter's name and type. *)
    21 fun varify ctxt thm =
       
    22   let
       
           (* indexes of the new variables start one above Thm.maxidx_of thm *)
    23     val maxidx = Thm.maxidx_of thm + 1
       
    24     val vs = params_of (Thm.prop_of thm)
       
           (* the i-th parameter (n, T) becomes Var ((n, i + maxidx), T) *)
    25     val vars = map_index (fn (i, (n, T)) => Var ((n, i + maxidx), T)) vs
       
    26   in Drule.forall_elim_list (map (SMT2_Util.certify ctxt) vars) thm end
       
    27 
       
    (* Extend an association list (the implicit final argument) so that every
       name in names maps to the type of the positionally corresponding
       Pure.all-parameter of t; AList.update replaces an existing entry for the
       same name.
       NOTE(review): the behaviour of fold2 when names and the parameter list
       differ in length is not visible here -- confirm against the library. *)
    28 fun add_paramTs names t =
       
    29   fold2 (fn n => fn (_, T) => AList.update (op =) (n, T)) names (params_of t)
       
    30 
       
    (* Fix one new variable per entry of nTs and return the extended context
       together with a symbol table that maps each original name to the
       certified Free of the new variable at the recorded type. *)
    31 fun new_fixes ctxt nTs =
       
    32   let
       
           (* empty base names: the actual names are chosen by Variable.variant_fixes *)
    33     val (ns, ctxt') = Variable.variant_fixes (replicate (length nTs) "") ctxt
       
           (* n is the original name (table key), n' the newly fixed name *)
    34     fun mk (n, T) n' = (n, SMT2_Util.certify ctxt' (Free (n', T)))
       
    35   in (ctxt', Symtab.make (map2 mk nTs ns)) end
       
    36 
       
    (* Term-level counterpart of forall elimination: for a term of the form
       Pure.all applied to an abstraction, beta-apply that abstraction to the
       term underlying ct.  Any other term raises TERM.
       NOTE(review): the exception tag reads "forall_elim'" although the
       function is called forall_elim_term -- confirm whether that is intended. *)
    37 fun forall_elim_term ct (Const (@{const_name Pure.all}, _) $ (a as Abs _)) =
       
    38       Term.betapply (a, Thm.term_of ct)
       
    39   | forall_elim_term _ qt = raise TERM ("forall_elim'", [qt])
       
    40 
       
    (* Fold over a list of names (the implicit next argument): look each name
       up in env and apply elim with the value found.  The lookup result is
       unwrapped with `the`, so every name is expected to be present in env. *)
    41 fun apply_fixes elim env = fold (elim o the o Symtab.lookup env)
       
    42 
       
    (* Theorem version: takes env and then a pair (names, thm), instantiating
       with Thm.forall_elim once per name. *)
    43 val apply_fixes_prem = uncurry o apply_fixes Thm.forall_elim
       
       (* Term version: takes env, names and a term, using forall_elim_term. *)
    44 val apply_fixes_concl = apply_fixes forall_elim_term
       
    45 
       
    (* Generalise a theorem (the implicit final argument) over the env entries
       of names via Drule.forall_intr_list; each name must be present in env. *)
    46 fun export_fixes env names = Drule.forall_intr_list (map (the o Symtab.lookup env) names)
       
    47 
       
    (* Run the proof method f with the quantified parameters of its inputs
       replaced by newly fixed variables, then quantify the result again.
         prems  - theorems whose parameters are turned into schematic variables
         nthms  - (names, thm) pairs whose parameters are instantiated by fixes
         names  - names for the parameters of the conclusion
         concl  - the conclusion term handed to f after instantiation *)
    48 fun under_fixes f ctxt (prems, nthms) names concl =
       
    49   let
       
    50     val thms1 = map (varify ctxt) prems
       
           (* collect name/type pairs from the conclusion and from every named
              premise, starting from an empty association list, then fix them *)
    51     val (ctxt', env) =
       
    52       add_paramTs names concl []
       
    53       |> fold (uncurry add_paramTs o apsnd Thm.prop_of) nthms
       
    54       |> new_fixes ctxt
       
    55     val thms2 = map (apply_fixes_prem env) nthms
       
    56     val t = apply_fixes_concl env names concl
       
         (* f runs in the extended context ctxt'; its result is re-generalised
            over the variables that stand for names *)
    57   in export_fixes env names (f ctxt' (thms1 @ thms2) t) end
       
    58 
       
    (* Produce the theorem for a single proof step.
       For assumption rules the theorem recorded under the step id in the
       `assumed` table is returned; if there is none, the certified conclusion
       is assumed via Thm.assume.  For all other rules the method selected by
       Z3_New_Replay_Methods.method_for is run through under_fixes. *)
    59 fun replay_thm ctxt assumed nthms
       
    60     (Z3_New_Proof.Z3_Step {id, rule, concl, fixes, is_fix_step, ...}) =
       
    61   if Z3_New_Proof.is_assumption rule then
       
    62     (case Inttab.lookup assumed id of
       
    63       SOME (_, thm) => thm
       
    64     | NONE => Thm.assume (SMT2_Util.certify ctxt concl))
       
    65   else
       
    66     under_fixes (Z3_New_Replay_Methods.method_for rule) ctxt
       
             (* fix steps pass their premises as plain theorems (first component,
                names dropped); other steps pass them as (names, thm) pairs *)
    67       (if is_fix_step then (map snd nthms, []) else ([], nthms)) fixes concl
       
    68 
       
    (* Replay one step and record the result in the table `proofs` under the
       step id, as a pair (fixes, thm).  The results of all premises are fetched
       from `proofs` with `the`, so they are expected to be present already. *)
    69 fun replay_step ctxt assumed (step as Z3_New_Proof.Z3_Step {id, prems, fixes, ...}) proofs =
       
    70   let val nthms = map (the o Inttab.lookup proofs) prems
       
    71   in Inttab.update (id, (fixes, replay_thm ctxt assumed nthms step)) proofs end
       
    72 
       
    (* add_asserted: match the assumption steps of a proof against the known
       assumptions.  The helpers in the local part are private to it. *)
    73 local
       
         (* trigger_def and fun_app_def in meta-equality form *)
    74   val remove_trigger = mk_meta_eq @{thm trigger_def}
       
    75   val remove_fun_app = mk_meta_eq @{thm fun_app_def}
       
    76 
       
         (* conversion that rewrites with exactly the given equations (on top of
            an empty simpset); no equations means the identity conversion *)
    77   fun rewrite_conv _ [] = Conv.all_conv
       
    78     | rewrite_conv ctxt eqs = Simplifier.full_rewrite (empty_simpset ctxt addsimps eqs)
       
    79 
       
         (* rules that are always part of the rewriting of assumptions *)
    80   val prep_rules = [@{thm Let_def}, remove_trigger, remove_fun_app,
       
    81     Z3_New_Replay_Literals.rewrite_true]
       
    82 
       
         (* theorem-level rewriting; the identity when there are no equations *)
    83   fun rewrite _ [] = I
       
    84     | rewrite ctxt eqs = Conv.fconv_rule (rewrite_conv ctxt eqs)
       
    85 
       
         (* candidates for ct from the net, each paired with a flag telling
            whether its proposition coincides with ct (aconvc) *)
    86   fun lookup_assm assms_net ct =
       
    87     Z3_New_Replay_Util.net_instances assms_net ct
       
    88     |> map (fn ithm as (_, thm) => (ithm, Thm.cprop_of thm aconvc ct))
       
    89 in
       
    90 
       
       (* Result: ((iidths, thms), (ctxt, ptab)) where
            iidths - (assumption index, (step id, assumption theorem)) entries
            thms   - assumption theorems that matched a step only inexactly
            ctxt   - context extended by the assumptions made along the way
            ptab   - table from step id to (fixes, theorem) *)
    91 fun add_asserted outer_ctxt rewrite_rules assms steps ctxt =
       
    92   let
       
    93     val eqs = map (rewrite ctxt [Z3_New_Replay_Literals.rewrite_true]) rewrite_rules
       
    94     val eqs' = union Thm.eq_thm eqs prep_rules
       
    95 
       
           (* assumptions are rewritten with eqs' and eta-converted before being
              stored in the net; revert_conv below applies the same two
              conversions to a step's conclusion *)
    96     val assms_net =
       
    97       assms
       
    98       |> map (apsnd (rewrite ctxt eqs'))
       
    99       |> map (apsnd (Conv.fconv_rule Thm.eta_conversion))
       
   100       |> Z3_New_Replay_Util.thm_net_of snd
       
   101 
       
   102     fun revert_conv ctxt = rewrite_conv ctxt eqs' then_conv Thm.eta_conversion
       
   103 
       
           (* add the first premise of thm as an assumption of the context and
              resolve the assumed theorem against thm *)
   104     fun assume thm ctxt =
       
   105       let
       
   106         val ct = Thm.cprem_of thm 1
       
   107         val (thm', ctxt') = yield_singleton Assumption.add_assumes ct ctxt
       
   108       in (thm' RS thm, ctxt') end
       
   109 
       
           (* handle one candidate assumption (i, th) for step id: an exact
              match discharges the premise of thm1 with th directly; otherwise
              the premise is assumed and th is recorded in thms *)
   110     fun add1 id fixes thm1 ((i, th), exact) ((iidths, thms), (ctxt, ptab)) =
       
   111       let
       
   112         val (thm, ctxt') = if exact then (Thm.implies_elim thm1 th, ctxt) else assume thm1 ctxt
       
   113         val thms' = if exact then thms else th :: thms
       
   114       in (((i, (id, th)) :: iidths, thms'), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
       
   115 
       
           (* only assumption steps other than Hypothesis are processed; every
              other step leaves the accumulator cx unchanged *)
   116     fun add (Z3_New_Proof.Z3_Step {id, rule, concl, fixes, ...})
       
   117         (cx as ((iidths, thms), (ctxt, ptab))) =
       
   118       if Z3_New_Proof.is_assumption rule andalso rule <> Z3_New_Proof.Hypothesis then
       
   119         let
       
   120           val ct = SMT2_Util.certify ctxt concl
       
               (* trivial implication on ct whose premise side is rewritten by
                  revert_conv; thm2 is the same theorem exported to outer_ctxt *)
   121           val thm1 = Thm.trivial ct |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
       
   122           val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
       
   123         in
       
   124           (case lookup_assm assms_net (Thm.cprem_of thm2 1) of
       
                 (* no known assumption matches: assume the premise locally *)
   125             [] =>
       
   126               let val (thm, ctxt') = assume thm1 ctxt
       
   127               in ((iidths, thms), (ctxt', Inttab.update (id, (fixes, thm)) ptab)) end
       
   128           | ithms => fold (add1 id fixes thm1) ithms cx)
       
   129         end
       
   130       else
       
   131         cx
       
   132   in fold add steps (([], []), (ctxt, Inttab.empty)) end
       
   133 
       
   134 end
       
   135 
       
   136 (* |- (EX x. P x) = P c     |- ~ (ALL x. P x) = ~ P c *)
       
   137 local
       
         (* both equations above, each under the premise that c is the
            corresponding SOME-term; proved by metis with someI_ex *)
   138   val sk_rules = @{lemma
       
   139     "c = (SOME x. P x) ==> (EX x. P x) = P c"
       
   140     "c = (SOME x. ~ P x) ==> (~ (ALL x. P x)) = (~ P c)"
       
   141     by (metis someI_ex)+}
       
   142 in
       
   143 
       
       (* Tactic for subgoal i: apply trans, then one of sk_rules, then close
          subgoal i+1 by refl or by a recursive call of this tactic, and finally
          close subgoal i by refl. *)
   144 fun discharge_sk_tac i st =
       
   145   (rtac @{thm trans} i
       
   146    THEN resolve_tac sk_rules i
       
   147    THEN (rtac @{thm refl} ORELSE' discharge_sk_tac) (i+1)
       
   148    THEN rtac @{thm refl} i) st
       
   149 
       
   150 end
       
   151 
       
   (* The given rules followed by a fixed tail: allI, refl, reflexive and
      Z3_New_Replay_Literals.true_thm. *)
   152 fun make_discharge_rules rules = rules @ [@{thm allI}, @{thm refl},
       
   153   @{thm reflexive}, Z3_New_Replay_Literals.true_thm]
       
   154 
       
   (* Two conjunctions of excluded-middle disjunctions (in both orders), proved
      by fast; discharge_assms_tac tries them before the supplied rules. *)
   155 val intro_def_rules = @{lemma
       
   156   "(~ P | P) & (P | ~ P)"
       
   157   "(P | ~ P) & (~ P | P)"
       
   158   by fast+}
       
   159 
       
   (* Repeatedly work on the head goal: resolve it with intro_def_rules or the
      given rules, or else solve it completely with discharge_sk_tac. *)
   160 fun discharge_assms_tac rules =
       
   161   REPEAT (HEADGOAL (resolve_tac (intro_def_rules @ rules) ORELSE' SOLVED' discharge_sk_tac))
       
   162 
       
   (* Discharge the premises of thm with discharge_assms_tac and pass the
      result through Goal.norm_result.  A theorem without premises skips the
      tactic.  Raises THM "failed to discharge premise" when the tactic
      yields no result. *)
   163 fun discharge_assms ctxt rules thm =
       
   164   (if Thm.nprems_of thm = 0 then
       
   165      thm
       
   166    else
       
          (* only the first result of the tactic is used *)
   167      (case Seq.pull (discharge_assms_tac rules thm) of
       
   168        SOME (thm', _) => thm'
       
   169      | NONE => raise THM ("failed to discharge premise", 1, [thm])))
       
   170   |> Goal.norm_result ctxt
       
   171 
       
   (* Export a theorem (the implicit final argument) from inner_ctxt to
      outer_ctxt, then discharge its premises in outer_ctxt using the given
      rules extended by make_discharge_rules. *)
   172 fun discharge rules outer_ctxt inner_ctxt =
       
   173   singleton (Proof_Context.export inner_ctxt outer_ctxt)
       
   174   #> discharge_assms outer_ctxt (make_discharge_rules rules)
       
   175 
       
   (* Parse the solver output into proof steps and build the parsed_proof
      record: no outcome, the facts referenced by assumption steps, and a
      delayed computation of the ATP-style proof.  Nothing is replayed here. *)
   176 fun parse_proof outer_ctxt
       
   177     ({context = ctxt, typs, terms, ll_defs, rewrite_rules, assms} : SMT2_Translate.replay_data)
       
   178     xfacts prems concl output =
       
   179   let
       
   180     val (steps, ctxt2) = Z3_New_Proof.parse typs terms output ctxt
       
           (* only the (index, (step id, theorem)) list of add_asserted is needed *)
   181     val ((iidths, _), _) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
       
   182 
       
           (* step id recorded for assumption index i, or ~1 if there is none *)
   183     fun id_of_index i = the_default ~1 (Option.map fst (AList.lookup (op =) iidths i))
       
   184 
       
           (* index layout used below: 0 for the conjecture, then one index per
              premise, then the extra facts.
              NOTE(review): presumably this mirrors the order in which the
              assertions were passed to the solver -- confirm against the caller *)
   185     val conjecture_i = 0
       
   186     val prems_i = 1
       
   187     val facts_i = prems_i + length prems
       
   188 
       
   189     val conjecture_id = id_of_index conjecture_i
       
   190     val prem_ids = map id_of_index (prems_i upto facts_i - 1)
       
           (* entries with index ~1; `try` turns the failing pattern match on any
              other index into NONE, so those entries are dropped *)
   191     val helper_ids = map_filter (try (fn (~1, idth) => idth)) iidths
       
           (* entries whose index, shifted by facts_i, selects an element of
              xfacts; entries for which nth fails are dropped by `try` *)
   192     val fact_ids = map_filter (fn (i, (id, _)) => try (apsnd (nth xfacts)) (id, i - facts_i)) iidths
       
           (* (name, proposition) pairs of helpers followed by those of facts *)
   193     val fact_helper_ts =
       
   194       map (fn (_, th) => (ATP_Util.short_thm_name ctxt th, prop_of th)) helper_ids @
       
   195       map (fn (_, ((s, _), th)) => (s, prop_of th)) fact_ids
       
           (* (step id, name) pairs of helpers followed by those of facts *)
   196     val fact_helper_ids =
       
   197       map (apsnd (ATP_Util.short_thm_name ctxt)) helper_ids @ map (apsnd (fst o fst)) fact_ids
       
   198   in
       
         (* atp_proof is a function of unit, so the conversion runs only when it
            is called *)
   199     {outcome = NONE, fact_ids = fact_ids,
       
   200      atp_proof = fn () => Z3_New_Isar.atp_proof_of_z3_proof ctxt ll_defs rewrite_rules prems concl
       
   201        fact_helper_ts prem_ids conjecture_id fact_helper_ids steps}
       
   202   end
       
   203 
       
   (* Parse the solver output, replay every proof step in order and return
      the theorem of the last step, exported to outer_ctxt with its remaining
      premises discharged. *)
   204 fun replay outer_ctxt
       
   205     ({context = ctxt, typs, terms, rewrite_rules, assms, ...} : SMT2_Translate.replay_data) output =
       
   206   let
       
   207     val (steps, ctxt2) = Z3_New_Proof.parse typs terms output ctxt
       
           (* rules: assumption theorems that add_asserted matched inexactly;
              assumed: table of theorems for the assumption steps *)
   208     val ((_, rules), (ctxt3, assumed)) = add_asserted outer_ctxt rewrite_rules assms steps ctxt2
       
           (* replay context: simpset from Z3_New_Replay_Util.make_simpset and
              the SAT solver taken from the SMT2_Config.sat_solver option *)
   209     val ctxt4 =
       
   210       ctxt3
       
   211       |> put_simpset (Z3_New_Replay_Util.make_simpset ctxt3 [])
       
   212       |> Config.put SAT.solver (Config.get ctxt3 SMT2_Config.sat_solver)
       
           (* the table of replayed steps starts out as the `assumed` table *)
   213     val proofs = fold (replay_step ctxt4 assumed) steps assumed
       
           (* NOTE(review): split_last presumably fails on an empty step list --
              confirm that the parser always yields at least one step *)
   214     val (_, Z3_New_Proof.Z3_Step {id, ...}) = split_last steps
       
   215   in
       
   216     Inttab.lookup proofs id |> the |> snd |> discharge rules outer_ctxt ctxt4
       
   217   end
       
   218 
       
   219 end;