simplify index handling
authorblanchet
Thu, 13 Mar 2014 14:48:20 +0100
changeset 56106 9cfea3ab002a
parent 56105 75dc126f5dcb
child 56107 2ec2d06b9424
simplify index handling
src/HOL/Tools/SMT2/smt2_normalize.ML
src/HOL/Tools/SMT2/smt2_solver.ML
--- a/src/HOL/Tools/SMT2/smt2_normalize.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/SMT2/smt2_normalize.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -10,8 +10,7 @@
   val atomize_conv: Proof.context -> conv
   type extra_norm = Proof.context -> thm list * thm list -> thm list * thm list
   val add_extra_norm: SMT2_Util.class * extra_norm -> Context.generic -> Context.generic
-  val normalize: (int * (int option * thm)) list -> Proof.context ->
-    (int * thm) list * Proof.context
+  val normalize: Proof.context -> (int option * thm) list -> (int * thm) list
 end
 
 structure SMT2_Normalize: SMT2_NORMALIZE =
@@ -497,7 +496,7 @@
   let
     val (is, thms) = split_list ithms
     val (thms', extra_thms) = f thms
-  in (is ~~ thms') @ tag_list (fold Integer.max is 0 + 1) extra_thms end
+  in (is ~~ thms') @ map (pair ~1) extra_thms end
 
 fun unfold2 ctxt ithms =
   ithms
@@ -558,14 +557,14 @@
 
 end
 
-fun normalize iwthms ctxt =
-  iwthms
+fun normalize ctxt wthms =
+  wthms
+  |> map_index I
   |> gen_normalize ctxt
   |> unfold1 ctxt
   |> monomorph ctxt
   |> unfold2 ctxt
   |> apply_extra_norms ctxt
-  |> rpair ctxt
 
 val _ = Theory.setup (Context.theory_map (
   setup_atomize #>
--- a/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 14:48:20 2014 +0100
+++ b/src/HOL/Tools/SMT2/smt2_solver.ML	Thu Mar 13 14:48:20 2014 +0100
@@ -26,7 +26,7 @@
   val add_solver: solver_config -> theory -> theory
   val solver_name_of: Proof.context -> string
   val available_solvers_of: Proof.context -> string list
-  val apply_solver: Proof.context -> (int * (int option * thm)) list ->
+  val apply_solver: Proof.context -> (int option * thm) list ->
     ((int * (int * thm)) list * Z3_New_Proof.z3_step list) * thm
   val default_max_relevant: Proof.context -> string -> int
 
@@ -155,6 +155,18 @@
     ((int * (int * thm)) list * Z3_New_Proof.z3_step list) * thm) option }
 
 
+(* check well-sortedness *)
+
+val has_topsort = Term.exists_type (Term.exists_subtype (fn
+    TFree (_, []) => true
+  | TVar (_, []) => true
+  | _ => false))
+
+(* top sorts cause problems with atomization *)
+fun check_topsort ctxt thm =
+  if has_topsort (Thm.prop_of thm) then (SMT2_Normalize.drop_fact_warning ctxt thm; TrueI) else thm
+
+
 (* registry *)
 
 type solver_info = {
@@ -227,36 +239,19 @@
   let val name = solver_name_of ctxt
   in (name, get_info ctxt name) end
 
-fun apply_solver ctxt0 iwthms =
+fun apply_solver ctxt wthms0 =
   let
-    val (ithms, ctxt) = SMT2_Normalize.normalize iwthms ctxt0
+    val wthms = map (apsnd (check_topsort ctxt)) wthms0
     val (name, {command, replay, ...}) = name_and_info_of ctxt
-  in replay ctxt (invoke name command ithms ctxt) end
+  in replay ctxt (invoke name command (SMT2_Normalize.normalize ctxt wthms) ctxt) end
 
 val default_max_relevant = #default_max_relevant oo get_info
 val supports_filter = #supports_filter o snd o name_and_info_of 
 
 
-(* check well-sortedness *)
-
-val has_topsort = Term.exists_type (Term.exists_subtype (fn
-    TFree (_, []) => true
-  | TVar (_, []) => true
-  | _ => false))
-
-(* without this test, we would run into problems when atomizing the rules: *)
-fun check_topsort ctxt thm =
-  if has_topsort (Thm.prop_of thm) then
-    (SMT2_Normalize.drop_fact_warning ctxt thm; TrueI)
-  else
-    thm
-
-fun check_topsorts ctxt iwthms = map (apsnd (apsnd (check_topsort ctxt))) iwthms
-
-
 (* filter *)
 
-val cnot = Thm.cterm_of @{theory} @{const Not}
+val no_id = ~1
 
 fun smt2_filter ctxt goal xwfacts i time_limit =
   let
@@ -267,67 +262,63 @@
       |> Config.put SMT2_Config.timeout (Time.toReal time_limit)
 
     val ({context=ctxt, prems, concl, ...}, _) = Subgoal.focus ctxt i goal
-    fun negate ct = Thm.dest_comb ct ||> Thm.apply cnot |-> Thm.apply
+    fun negate ct = Thm.dest_comb ct ||> Thm.apply @{cterm Not} |-> Thm.apply
     val cprop =
       (case try negate (Thm.rhs_of (SMT2_Normalize.atomize_conv ctxt concl)) of
         SOME ct => ct
       | NONE => raise SMT2_Failure.SMT (SMT2_Failure.Other_Failure "goal is not a HOL term"))
 
-    val iwconjecture = (~1, (NONE, Thm.assume cprop))
-    val iwprems = map (pair ~2 o pair NONE) prems
-    val iwfacts = map_index I (map snd xwfacts)
+    val wconjecture = (NONE, Thm.assume cprop)
+    val wprems = map (pair NONE) prems
+    val wfacts = map snd xwfacts
+    val wthms = wconjecture :: wprems @ wfacts
+    val iwthms = map_index I wthms
 
-    val n = length iwfacts
-    val xfacts = map (apsnd snd) xwfacts
+    val conjecture_i = 0
+    val facts_i = 1 + length wprems
   in
-    iwconjecture :: iwprems @ iwfacts
-    |> check_topsorts ctxt
+    wthms
     |> apply_solver ctxt
     |> fst
     |> (fn (iidths0, z3_proof) =>
-      let val iidths = if supports_filter ctxt then iidths0 else map (apsnd (apfst (K ~1))) iwfacts
+      let
+        val iidths = if supports_filter ctxt then iidths0 else map (apsnd (apfst (K no_id))) iwthms
       in
         {outcome = NONE, 
          conjecture_id =
-           the_default ~1 (Option.map fst (AList.lookup (op =) iidths (fst iwconjecture))),
-         helper_ids = map_filter (fn (i, (id, th)) => if i >= n then SOME (id, th) else NONE) iidths,
-         fact_ids = map_filter (fn (i, (id, _)) => try (apsnd (nth xfacts)) (id, i)) iidths,
+           the_default no_id (Option.map fst (AList.lookup (op =) iidths conjecture_i)),
+         helper_ids = map_filter (try (fn (~1, idth) => idth)) iidths,
+         fact_ids = map_filter (fn (i, (id, _)) =>
+           try (apsnd (apsnd snd o nth xwfacts)) (id, i - facts_i)) iidths,
          z3_proof = z3_proof}
       end)
   end
-  handle SMT2_Failure.SMT fail => {outcome = SOME fail, conjecture_id = ~1, helper_ids = [],
+  handle SMT2_Failure.SMT fail => {outcome = SOME fail, conjecture_id = no_id, helper_ids = [],
     fact_ids = [], z3_proof = []}
 
 
 (* SMT tactic *)
 
 local
-  fun trace_assumptions ctxt iwfacts iidths =
-    let
-      val wfacts =
-        iidths
-        |> map fst
-        |> filter (fn i => i >= 0)
-        |> map_filter (AList.lookup (op =) iwfacts)
-    in
+  fun trace_assumptions ctxt wfacts iidths =
+    let val used = map_filter (try (snd o nth wfacts) o fst) iidths in
       if Config.get ctxt SMT2_Config.trace_used_facts andalso length wfacts > 0 then
         tracing (Pretty.string_of (Pretty.big_list "SMT used facts:"
-          (map (Display.pretty_thm ctxt o snd) wfacts)))
+          (map (Display.pretty_thm ctxt) used)))
       else ()
     end
 
-  fun solve ctxt iwfacts =
-    iwfacts
-    |> check_topsorts ctxt
+  fun solve ctxt wfacts =
+    wfacts
     |> apply_solver ctxt
-    |>> apfst (trace_assumptions ctxt iwfacts)
+    |>> apfst (trace_assumptions ctxt wfacts)
     |> snd
 
   fun str_of ctxt fail =
     SMT2_Failure.string_of_failure ctxt fail
     |> prefix ("Solver " ^ SMT2_Config.solver_of ctxt ^ ": ")
 
-  fun safe_solve ctxt iwfacts = SOME (solve ctxt iwfacts)
+  fun safe_solve ctxt wfacts = SOME (solve ctxt wfacts)
     handle
       SMT2_Failure.SMT (fail as SMT2_Failure.Counterexample _) =>
         (SMT2_Config.verbose_msg ctxt (str_of ctxt) fail; NONE)
@@ -337,17 +328,14 @@
           "configuration option " ^ quote (Config.name_of SMT2_Config.timeout) ^ " might help)")
     | SMT2_Failure.SMT fail => error (str_of ctxt fail)
 
-  fun tag_rules thms = map_index (apsnd (pair NONE)) thms
-  fun tag_prems thms = map (pair ~1 o pair NONE) thms
-
   fun resolve (SOME thm) = rtac thm 1
     | resolve NONE = no_tac
 
   fun tac prove ctxt rules =
     CONVERSION (SMT2_Normalize.atomize_conv ctxt)
     THEN' rtac @{thm ccontr}
-    THEN' SUBPROOF (fn {context, prems, ...} =>
-      resolve (prove context (tag_rules rules @ tag_prems prems))) ctxt
+    THEN' SUBPROOF (fn {context = ctxt, prems, ...} =>
+      resolve (prove ctxt (map (pair NONE) (rules @ prems)))) ctxt
 in
 
 val smt2_tac = tac safe_solve