src/HOL/SMT/Tools/z3_proof_reconstruction.ML
changeset 36896 c030819254d3
parent 36895 a96f9793d9c5
child 36897 6d1ecdb81ff0
equal deleted inserted replaced
36895:a96f9793d9c5 36896:c030819254d3
     5 *)
     5 *)
     6 
     6 
     7 signature Z3_PROOF_RECONSTRUCTION =
     7 signature Z3_PROOF_RECONSTRUCTION =
     8 sig
     8 sig
     9   val trace_assms: bool Config.T
     9   val trace_assms: bool Config.T
    10   val reconstruct: Proof.context -> SMT_Translate.recon -> string list -> thm
    10   val reconstruct: string list * SMT_Translate.recon -> Proof.context ->
       
    11     thm * Proof.context
    11   val setup: theory -> theory
    12   val setup: theory -> theory
    12 end
    13 end
    13 
    14 
    14 structure Z3_Proof_Reconstruction: Z3_PROOF_RECONSTRUCTION =
    15 structure Z3_Proof_Reconstruction: Z3_PROOF_RECONSTRUCTION =
    15 struct
    16 struct
   116   | literals_of p = L.make_littab [thm_of p]
   117   | literals_of p = L.make_littab [thm_of p]
   117 
   118 
   118 
   119 
   119 (* proof representation *)
   120 (* proof representation *)
   120 
   121 
   121 datatype proof =
   122 datatype proof = Unproved of P.proof_step | Proved of theorem
   122   Unproved of P.proof_step |
       
   123   Sequent of { hyps: cterm list, thm: theorem }
       
   124 
   123 
   125 
   124 
   126 
   125 
   127 (** core proof rules **)
   126 (** core proof rules **)
   128 
   127 
   154         quote (Syntax.string_of_term ctxt (Thm.term_of ct))))
   153         quote (Syntax.string_of_term ctxt (Thm.term_of ct))))
   155 in
   154 in
   156 fun prepare_assms unfolds assms =
   155 fun prepare_assms unfolds assms =
   157   let
   156   let
   158     val unfolds' = rewrite_rules [L.rewrite_true] unfolds
   157     val unfolds' = rewrite_rules [L.rewrite_true] unfolds
   159     val assms' = rewrite_rules (unfolds' @ prep_rules) assms
   158     val assms' = rewrite_rules (union Thm.eq_thm unfolds' prep_rules) assms
   160   in (unfolds', T.thm_net_of assms') end
   159   in (unfolds', T.thm_net_of assms') end
   161 
   160 
   162 fun asserted _ NONE ct = Thm (Thm.assume ct)
   161 fun asserted _ NONE ct = Thm (Thm.assume ct)
   163   | asserted ctxt (SOME (unfolds, assms)) ct =
   162   | asserted ctxt (SOME (unfolds, assms)) ct =
   164       let val revert_conv = rewrite_conv ctxt unfolds
   163       let val revert_conv = rewrite_conv ctxt unfolds
   194     let
   193     let
   195       val lit = the (L.get_first_lit (is_sublit conj t) lits)
   194       val lit = the (L.get_first_lit (is_sublit conj t) lits)
   196       val ls = L.explode conj false false [t] lit
   195       val ls = L.explode conj false false [t] lit
   197       val lits' = fold L.insert_lit ls (L.delete_lit lit lits)
   196       val lits' = fold L.insert_lit ls (L.delete_lit lit lits)
   198 
   197 
   199       fun upd (Sequent {hyps, thm}) =
   198       fun upd (Proved thm) = Proved (Literals (thm_of thm, lits'))
   200             Sequent {hyps = hyps, thm = Literals (thm_of thm, lits')}
       
   201         | upd p = p
   199         | upd p = p
   202     in (the (L.lookup_lit lits' t), Inttab.map_entry idx upd ptab) end
   200     in (the (L.lookup_lit lits' t), Inttab.map_entry idx upd ptab) end
   203 
   201 
   204   fun lit_elim conj (p, idx) ct ptab =
   202   fun lit_elim conj (p, idx) ct ptab =
   205     let val lits = literals_of p
   203     let val lits = literals_of p
   347   fun apply_rule ct =
   345   fun apply_rule ct =
   348     (case get_first (try (inst_rule ct)) intro_rules of
   346     (case get_first (try (inst_rule ct)) intro_rules of
   349       SOME thm => thm
   347       SOME thm => thm
   350     | NONE => raise CTERM ("intro_def", [ct]))
   348     | NONE => raise CTERM ("intro_def", [ct]))
   351 in
   349 in
   352 fun intro_def ct = apsnd Thm (T.make_hyp_def (apply_rule ct))
   350 fun intro_def ct = T.make_hyp_def (apply_rule ct) #>> Thm
   353 
   351 
   354 fun apply_def thm =
   352 fun apply_def thm =
   355   get_first (try (fn rule => MetaEq (thm COMP rule))) apply_rules
   353   get_first (try (fn rule => MetaEq (thm COMP rule))) apply_rules
   356   |> the_default (Thm thm)
   354   |> the_default (Thm thm)
   357 end
   355 end
   588     |> Conv.fconv_rule (Thm.beta_conversion true)
   586     |> Conv.fconv_rule (Thm.beta_conversion true)
   589 
   587 
   590   fun kind (Const (@{const_name Ex}, _) $ _) = (sk_ex_rule, I, I)
   588   fun kind (Const (@{const_name Ex}, _) $ _) = (sk_ex_rule, I, I)
   591     | kind (@{term Not} $ (Const (@{const_name All}, _) $ _)) =
   589     | kind (@{term Not} $ (Const (@{const_name All}, _) $ _)) =
   592         (sk_all_rule, Thm.dest_arg, Thm.capply @{cterm Not})
   590         (sk_all_rule, Thm.dest_arg, Thm.capply @{cterm Not})
   593     | kind _ = z3_exn "skolemize: no quantifier"
   591     | kind t = raise TERM ("skolemize", [t])
   594 
   592 
   595   fun dest_abs_type (Abs (_, T, _)) = T
   593   fun dest_abs_type (Abs (_, T, _)) = T
   596     | dest_abs_type t = raise TERM ("dest_abs_type", [t])
   594     | dest_abs_type t = raise TERM ("dest_abs_type", [t])
   597 
   595 
   598   fun bodies_of thy lhs rhs =
   596   fun bodies_of thy lhs rhs =
   612         end
   610         end
   613     in (rule, dest_body 1 [] lhs) end
   611     in (rule, dest_body 1 [] lhs) end
   614 
   612 
   615   fun transitive f thm = Thm.transitive thm (f (Thm.rhs_of thm))
   613   fun transitive f thm = Thm.transitive thm (f (Thm.rhs_of thm))
   616 
   614 
   617   fun sk_step (rule, elim) (cv, mct, cb) (is, thm) =
   615   fun sk_step (rule, elim) (cv, mct, cb) ((is, thm), ctxt) =
   618     (case mct of
   616     (case mct of
   619       SOME ct =>
   617       SOME ct =>
   620         T.make_hyp_def (inst_sk rule (Thm.instantiate_cterm ([], is) cb) ct)
   618         ctxt
   621         |> apsnd (pair ((cv, ct) :: is) o Thm.transitive thm)
   619         |> T.make_hyp_def (inst_sk rule (Thm.instantiate_cterm ([], is) cb) ct)
   622     | NONE => ([], (is, transitive (Conv.rewr_conv elim) thm)))
   620         |>> pair ((cv, ct) :: is) o Thm.transitive thm
   623 in
   621     | NONE => ((is, transitive (Conv.rewr_conv elim) thm), ctxt))
   624 fun skolemize ctxt ct =
   622 in
       
   623 fun skolemize ct ctxt =
   625   let
   624   let
   626     val (lhs, rhs) = Thm.dest_binop (Thm.dest_arg ct)
   625     val (lhs, rhs) = Thm.dest_binop (Thm.dest_arg ct)
   627     val (rule, (ctab, cbs)) = bodies_of (ProofContext.theory_of ctxt) lhs rhs
   626     val (rule, (ctab, cbs)) = bodies_of (ProofContext.theory_of ctxt) lhs rhs
   628     fun lookup_var (cv, cb) = (cv, AList.lookup (op aconvc) ctab cv, cb)
   627     fun lookup_var (cv, cb) = (cv, AList.lookup (op aconvc) ctab cv, cb)
   629   in
   628   in
   630     ([], Thm.reflexive lhs)
   629     (([], Thm.reflexive lhs), ctxt)
   631     |> fold_map (sk_step rule) (map lookup_var cbs)
   630     |> fold (sk_step rule) (map lookup_var cbs)
   632     |> apfst (rev o flat) o apsnd (MetaEq o snd)
   631     |>> MetaEq o snd
   633   end
   632   end
   634 end
   633 end
   635 
   634 
   636 
   635 
   637 
   636 
   700 
   699 
   701 local
   700 local
   702   fun count_rules ptab =
   701   fun count_rules ptab =
   703     let
   702     let
   704       fun count (_, Unproved _) (solved, total) = (solved, total + 1)
   703       fun count (_, Unproved _) (solved, total) = (solved, total + 1)
   705         | count (_, Sequent _) (solved, total) = (solved + 1, total + 1)
   704         | count (_, Proved _) (solved, total) = (solved + 1, total + 1)
   706     in Inttab.fold count ptab (0, 0) end
   705     in Inttab.fold count ptab (0, 0) end
   707 
   706 
   708   fun header idx r (solved, total) = 
   707   fun header idx r (solved, total) = 
   709     "Z3: #" ^ string_of_int idx ^ ": " ^ P.string_of_rule r ^ " (goal " ^
   708     "Z3: #" ^ string_of_int idx ^ ": " ^ P.string_of_rule r ^ " (goal " ^
   710     string_of_int (solved + 1) ^ " of " ^ string_of_int total ^ ")"
   709     string_of_int (solved + 1) ^ " of " ^ string_of_int total ^ ")"
   711 
   710 
   712   fun check ctxt idx r ps ct ((_, p), _) =
   711   fun check ctxt idx r ps ct p =
   713     let val thm = thm_of p |> tap (Thm.join_proofs o single)
   712     let val thm = thm_of p |> tap (Thm.join_proofs o single)
   714     in
   713     in
   715       if (Thm.cprop_of thm) aconvc ct then ()
   714       if (Thm.cprop_of thm) aconvc ct then ()
   716       else z3_exn (Pretty.string_of (Pretty.big_list ("proof step failed: " ^
   715       else z3_exn (Pretty.string_of (Pretty.big_list ("proof step failed: " ^
   717         quote (P.string_of_rule r) ^ " (#" ^ string_of_int idx ^ ")")
   716         quote (P.string_of_rule r) ^ " (#" ^ string_of_int idx ^ ")")
   718           (pretty_goal ctxt (map (thm_of o fst) ps) (Thm.prop_of thm) @
   717           (pretty_goal ctxt (map (thm_of o fst) ps) (Thm.prop_of thm) @
   719            [Pretty.block [Pretty.str "expected: ",
   718            [Pretty.block [Pretty.str "expected: ",
   720             Syntax.pretty_term ctxt (Thm.term_of ct)]])))
   719             Syntax.pretty_term ctxt (Thm.term_of ct)]])))
   721     end
   720     end
   722 in
   721 in
   723 fun trace_rule ctxt idx prove r ps ct ptab =
   722 fun trace_rule idx prove r ps ct (cxp as (ctxt, ptab)) =
   724   let
   723   let
   725     val _ = SMT_Solver.trace_msg ctxt (header idx r o count_rules) ptab
   724     val _ = SMT_Solver.trace_msg ctxt (header idx r o count_rules) ptab
   726     val result = prove r ps ct ptab
   725     val result as (p, cxp' as (ctxt', _)) = prove r ps ct cxp
   727     val _ = if not (Config.get ctxt SMT_Solver.trace) then ()
   726     val _ = if not (Config.get ctxt' SMT_Solver.trace) then ()
   728       else check ctxt idx r ps ct result
   727       else check ctxt' idx r ps ct p
   729   in result end
   728   in result end
   730 end
   729 end
   731 
   730 
   732 
   731 
   733 (* overall reconstruction procedure *)
   732 (* overall reconstruction procedure *)
   734 
   733 
   735 fun not_supported r =
   734 fun not_supported r =
   736   z3_exn ("proof rule not implemented: " ^ quote (P.string_of_rule r))
   735   raise Fail ("Z3: proof rule not implemented: " ^ quote (P.string_of_rule r))
   737 
   736 
   738 fun prove ctxt unfolds assms vars =
   737 fun prove ctxt unfolds assms vars =
   739   let
   738   let
   740     val assms' = Option.map (prepare_assms unfolds) assms
   739     val assms' = Option.map (prepare_assms unfolds) assms
   741     val simpset = T.make_simpset ctxt (Z3_Simps.get ctxt)
   740     val simpset = T.make_simpset ctxt (Z3_Simps.get ctxt)
   742 
   741 
   743     fun step r ps ct ptab =
   742     fun step r ps ct (cxp as (cx, ptab)) =
   744       (case (r, ps) of
   743       (case (r, ps) of
   745         (* core rules *)
   744         (* core rules *)
   746         (P.TrueAxiom, _) => (([], Thm L.true_thm), ptab)
   745         (P.TrueAxiom, _) => (Thm L.true_thm, cxp)
   747       | (P.Asserted, _) => (([], asserted ctxt assms' ct), ptab)
   746       | (P.Asserted, _) => (asserted cx assms' ct, cxp)
   748       | (P.Goal, _) => (([], asserted ctxt assms' ct), ptab)
   747       | (P.Goal, _) => (asserted cx assms' ct, cxp)
   749       | (P.ModusPonens, [(p, _), (q, _)]) => (([], mp q (thm_of p)), ptab)
   748       | (P.ModusPonens, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
   750       | (P.ModusPonensOeq, [(p, _), (q, _)]) => (([], mp q (thm_of p)), ptab)
   749       | (P.ModusPonensOeq, [(p, _), (q, _)]) => (mp q (thm_of p), cxp)
   751       | (P.AndElim, [(p, i)]) => apfst (pair []) (and_elim (p, i) ct ptab)
   750       | (P.AndElim, [(p, i)]) => and_elim (p, i) ct ptab ||> pair cx
   752       | (P.NotOrElim, [(p, i)]) => apfst (pair []) (not_or_elim (p, i) ct ptab)
   751       | (P.NotOrElim, [(p, i)]) => not_or_elim (p, i) ct ptab ||> pair cx
   753       | (P.Hypothesis, _) => (([], Thm (Thm.assume ct)), ptab)
   752       | (P.Hypothesis, _) => (Thm (Thm.assume ct), cxp)
   754       | (P.Lemma, [(p, _)]) => (([], lemma (thm_of p) ct), ptab)
   753       | (P.Lemma, [(p, _)]) => (lemma (thm_of p) ct, cxp)
   755       | (P.UnitResolution, (p, _) :: ps) =>
   754       | (P.UnitResolution, (p, _) :: ps) =>
   756           (([], unit_resolution (thm_of p) (map (thm_of o fst) ps) ct), ptab)
   755           (unit_resolution (thm_of p) (map (thm_of o fst) ps) ct, cxp)
   757       | (P.IffTrue, [(p, _)]) => (([], iff_true (thm_of p)), ptab)
   756       | (P.IffTrue, [(p, _)]) => (iff_true (thm_of p), cxp)
   758       | (P.IffFalse, [(p, _)]) => (([], iff_false (thm_of p)), ptab)
   757       | (P.IffFalse, [(p, _)]) => (iff_false (thm_of p), cxp)
   759       | (P.Distributivity, _) => (([], distributivity ctxt ct), ptab)
   758       | (P.Distributivity, _) => (distributivity cx ct, cxp)
   760       | (P.DefAxiom, _) => (([], def_axiom ctxt ct), ptab)
   759       | (P.DefAxiom, _) => (def_axiom cx ct, cxp)
   761       | (P.IntroDef, _) => (intro_def ct, ptab)
   760       | (P.IntroDef, _) => intro_def ct cx ||> rpair ptab
   762       | (P.ApplyDef, [(p, _)]) => (([], apply_def (thm_of p)), ptab)
   761       | (P.ApplyDef, [(p, _)]) => (apply_def (thm_of p), cxp)
   763       | (P.IffOeq, [(p, _)]) => (([], p), ptab)
   762       | (P.IffOeq, [(p, _)]) => (p, cxp)
   764       | (P.NnfPos, _) => (([], nnf ctxt vars (map fst ps) ct), ptab)
   763       | (P.NnfPos, _) => (nnf cx vars (map fst ps) ct, cxp)
   765       | (P.NnfNeg, _) => (([], nnf ctxt vars (map fst ps) ct), ptab)
   764       | (P.NnfNeg, _) => (nnf cx vars (map fst ps) ct, cxp)
   766 
   765 
   767         (* equality rules *)
   766         (* equality rules *)
   768       | (P.Reflexivity, _) => (([], refl ct), ptab)
   767       | (P.Reflexivity, _) => (refl ct, cxp)
   769       | (P.Symmetry, [(p, _)]) => (([], symm p), ptab)
   768       | (P.Symmetry, [(p, _)]) => (symm p, cxp)
   770       | (P.Transitivity, [(p, _), (q, _)]) => (([], trans p q), ptab)
   769       | (P.Transitivity, [(p, _), (q, _)]) => (trans p q, cxp)
   771       | (P.Monotonicity, _) => (([], monotonicity (map fst ps) ct), ptab)
   770       | (P.Monotonicity, _) => (monotonicity (map fst ps) ct, cxp)
   772       | (P.Commutativity, _) => (([], commutativity ct), ptab)
   771       | (P.Commutativity, _) => (commutativity ct, cxp)
   773 
   772 
   774         (* quantifier rules *)
   773         (* quantifier rules *)
   775       | (P.QuantIntro, [(p, _)]) => (([], quant_intro vars p ct), ptab)
   774       | (P.QuantIntro, [(p, _)]) => (quant_intro vars p ct, cxp)
   776       | (P.PullQuant, _) => (([], pull_quant ctxt ct), ptab)
   775       | (P.PullQuant, _) => (pull_quant cx ct, cxp)
   777       | (P.PushQuant, _) => (([], push_quant ctxt ct), ptab)
   776       | (P.PushQuant, _) => (push_quant cx ct, cxp)
   778       | (P.ElimUnusedVars, _) => (([], elim_unused_vars ctxt ct), ptab)
   777       | (P.ElimUnusedVars, _) => (elim_unused_vars cx ct, cxp)
   779       | (P.DestEqRes, _) => (([], dest_eq_res ctxt ct), ptab)
   778       | (P.DestEqRes, _) => (dest_eq_res cx ct, cxp)
   780       | (P.QuantInst, _) => (([], quant_inst ct), ptab)
   779       | (P.QuantInst, _) => (quant_inst ct, cxp)
   781       | (P.Skolemize, _) => (skolemize ctxt ct, ptab)
   780       | (P.Skolemize, _) => skolemize ct cx ||> rpair ptab
   782 
   781 
   783         (* theory rules *)
   782         (* theory rules *)
   784       | (P.ThLemma, _) =>
   783       | (P.ThLemma, _) =>
   785           (([], th_lemma ctxt simpset (map (thm_of o fst) ps) ct), ptab)
   784           (th_lemma cx simpset (map (thm_of o fst) ps) ct, cxp)
   786       | (P.Rewrite, _) => (([], rewrite ctxt simpset [] ct), ptab)
   785       | (P.Rewrite, _) => (rewrite cx simpset [] ct, cxp)
   787       | (P.RewriteStar, ps) =>
   786       | (P.RewriteStar, ps) =>
   788           (([], rewrite ctxt simpset (map fst ps) ct), ptab)
   787           (rewrite cx simpset (map fst ps) ct, cxp)
   789 
   788 
   790       | (P.NnfStar, _) => not_supported r
   789       | (P.NnfStar, _) => not_supported r
   791       | (P.CnfStar, _) => not_supported r
   790       | (P.CnfStar, _) => not_supported r
   792       | (P.TransitivityStar, _) => not_supported r
   791       | (P.TransitivityStar, _) => not_supported r
   793       | (P.PullQuantStar, _) => not_supported r
   792       | (P.PullQuantStar, _) => not_supported r
   794 
   793 
   795       | _ => z3_exn ("Proof rule " ^ quote (P.string_of_rule r) ^
   794       | _ => raise Fail ("Z3: proof rule " ^ quote (P.string_of_rule r) ^
   796          " has an unexpected number of arguments."))
   795          " has an unexpected number of arguments."))
   797 
   796 
   798     fun eq_hyp_def (ct, cu) = Thm.dest_arg1 ct aconvc Thm.dest_arg1 cu
   797     fun conclude idx rule prop (ps, cxp) =
   799       (* compare only the defined Frees, not the whole definitions *)
   798       trace_rule idx step rule ps prop cxp
   800 
   799       |-> (fn p => apsnd (Inttab.update (idx, Proved p)) #> pair p)
   801     fun conclude idx rule prop ((hypss, ps), ptab) =
   800 
   802       trace_rule ctxt idx step rule ps prop ptab
   801     fun lookup idx (cxp as (cx, ptab)) =
   803       |>> apfst (distinct eq_hyp_def o fold append hypss)
       
   804 
       
   805     fun add_sequent idx (hyps, thm) ptab =
       
   806       ((hyps, thm), Inttab.update (idx, Sequent {hyps=hyps, thm=thm}) ptab)
       
   807 
       
   808     fun lookup idx ptab =
       
   809       (case Inttab.lookup ptab idx of
   802       (case Inttab.lookup ptab idx of
   810         SOME (Unproved (P.Proof_Step {rule, prems, prop})) =>
   803         SOME (Unproved (P.Proof_Step {rule, prems, prop})) =>
   811           fold_map lookup prems ptab
   804           fold_map lookup prems cxp
   812           |>> split_list
   805           |>> map2 rpair prems
   813           |>> apsnd (fn ps => ps ~~ prems)
       
   814           |> conclude idx rule prop
   806           |> conclude idx rule prop
   815           |-> add_sequent idx
   807       | SOME (Proved p) => (p, cxp)
   816       | SOME (Sequent {hyps, thm}) => ((hyps, thm), ptab)
       
   817       | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
   808       | NONE => z3_exn ("unknown proof id: " ^ quote (string_of_int idx)))
   818 
   809 
   819     fun result (hyps, thm) =
   810     fun result (p, (cx, _)) = (thm_of p, cx)
   820       fold SMT_Normalize.discharge_definition hyps (thm_of thm)
       
   821   in
   811   in
   822     (fn (idx, ptab) => result (fst (lookup idx (Inttab.map Unproved ptab))))
   812     (fn (idx, ptab) => result (lookup idx (ctxt, Inttab.map Unproved ptab)))
   823   end
   813   end
   824 
   814 
   825 fun reconstruct ctxt {typs, terms, unfolds, assms} output =
   815 fun reconstruct (output, {typs, terms, unfolds, assms}) ctxt =
   826   P.parse ctxt typs terms output
   816   P.parse ctxt typs terms output
   827   |> (fn (idx, (ptab, vars, cx)) => prove cx unfolds assms vars (idx, ptab))
   817   |> (fn (idx, (ptab, vars, cx)) => prove cx unfolds assms vars (idx, ptab))
   828 
   818 
   829 val setup = trace_assms_setup #> z3_rules_setup #> Z3_Simps.setup
   819 val setup = trace_assms_setup #> z3_rules_setup #> Z3_Simps.setup
   830 
   820