src/Provers/Arith/fast_lin_arith.ML
changeset 20433 55471f940e5c
parent 20280 ad9fbbd01535
child 21109 f8f89be75e81
--- a/src/Provers/Arith/fast_lin_arith.ML	Wed Aug 30 03:19:08 2006 +0200
+++ b/src/Provers/Arith/fast_lin_arith.ML	Wed Aug 30 03:30:09 2006 +0200
@@ -590,7 +590,7 @@
 (*        failure as soon as a case could not be refuted; i.e. delay further *)
 (*        splitting until after a refutation for other cases has been found. *)
 
-fun split_items sg (Ts : typ list, terms : term list) :
+fun split_items sg (do_pre : bool) (Ts : typ list, terms : term list) :
                 (typ list * (LA_Data.decompT * int) list) list =
 let
 (*
@@ -635,7 +635,7 @@
   val result = (Ts, terms)
     |> (* user-defined preprocessing of the subgoal *)
        (* (typ list * term list) list *)
-       LA_Data.pre_decomp sg
+       (if do_pre then LA_Data.pre_decomp sg else Library.single)
     |> (* compute the internal encoding of (in-)equalities *)
        (* (typ list * (LA_Data.decompT option * bool) list) list *)
        map (apsnd (map (fn t => (LA_Data.decomp sg t, LA_Data.domain_is_nat t))))
@@ -712,10 +712,10 @@
 in refute end;
 
 fun refute (sg : theory) (params : (string * Term.typ) list) (show_ex : bool)
-           (terms : term list) : injust list option =
-  refutes sg params show_ex (split_items sg (map snd params, terms)) [];
+           (do_pre : bool) (terms : term list) : injust list option =
+  refutes sg params show_ex (split_items sg do_pre (map snd params, terms)) [];
 
-fun count (P : 'a -> bool) (xs : 'a list) : int = length (List.filter P xs);
+fun count P xs = length (List.filter P xs);
 
 (* The limit on the number of ~= allowed.
    Because each ~= is split into two cases, this can lead to an explosion.
@@ -723,7 +723,7 @@
 val fast_arith_neq_limit = ref 9;
 
 fun prove (sg : theory) (params : (string * Term.typ) list) (show_ex : bool)
-          (Hs : term list) (concl : term) : injust list option =
+          (do_pre : bool) (Hs : term list) (concl : term) : injust list option =
   let
     (* append the negated conclusion to 'Hs' -- this corresponds to     *)
     (* 'DETERM (resolve_tac [LA_Logic.notI, LA_Logic.ccontr] i)' at the *)
@@ -739,7 +739,7 @@
                    string_of_int (!fast_arith_neq_limit) ^ ")");
       NONE
     ) else
-      refute sg params show_ex Hs'
+      refute sg params show_ex do_pre Hs'
   end;
 
 fun refute_tac (ss : simpset) (i : int, justs : injust list) : tactic =
@@ -773,7 +773,7 @@
       val Hs     = Logic.strip_assums_hyp A
       val concl  = Logic.strip_assums_concl A
   in trace_thm ("Trying to refute subgoal " ^ string_of_int i) st;
-     case prove (Thm.sign_of_thm st) params show_ex Hs concl of
+     case prove (Thm.sign_of_thm st) params show_ex true Hs concl of
        NONE => (trace_msg "Refutation failed."; no_tac)
      | SOME js => (trace_msg "Refutation succeeded."; refute_tac ss (i, js))
   end) i st;
@@ -789,11 +789,59 @@
 
 (** Forward proof from theorems **)
 
+(* More tricky code. Needs to arrange the proofs of the multiple cases (due
+to splits of ~= premises) such that it coincides with the order of the cases
+generated by function split_items. *)
+
+datatype splittree = Tip of thm list
+                   | Spl of thm * cterm * splittree * cterm * splittree;
+
+(* "(ct1 ==> ?R) ==> (ct2 ==> ?R) ==> ?R" is taken to (ct1, ct2) *)
+
+fun extract (imp : cterm) : cterm * cterm =
+let val (Il, r)    = Thm.dest_comb imp
+    val (_, imp1)  = Thm.dest_comb Il
+    val (Ict1, _)  = Thm.dest_comb imp1
+    val (_, ct1)   = Thm.dest_comb Ict1
+    val (Ir, _)    = Thm.dest_comb r
+    val (_, Ict2r) = Thm.dest_comb Ir
+    val (Ict2, _)  = Thm.dest_comb Ict2r
+    val (_, ct2)   = Thm.dest_comb Ict2
+in (ct1, ct2) end;
+
+fun splitasms (sg : theory) (asms : thm list) : splittree =
+let val {neqE, ...} = Data.get sg
+    fun elim_neq (asms', []) = Tip (rev asms')
+      | elim_neq (asms', asm::asms) =
+      (case get_first (fn th => SOME (asm COMP th) handle THM _ => NONE) neqE of
+        SOME spl =>
+          let val (ct1, ct2) = extract (cprop_of spl)
+              val thm1 = assume ct1
+              val thm2 = assume ct2
+          in Spl (spl, ct1, elim_neq (asms', asms@[thm1]), ct2, elim_neq (asms', asms@[thm2]))
+          end
+      | NONE => elim_neq (asm::asms', asms))
+in elim_neq ([], asms) end;
+
+fun fwdproof (ctxt : theory * simpset) (Tip asms : splittree) (j::js : injust list) =
+    (mkthm ctxt asms j, js)
+  | fwdproof ctxt (Spl (thm, ct1, tree1, ct2, tree2)) js =
+    let val (thm1, js1) = fwdproof ctxt tree1 js
+        val (thm2, js2) = fwdproof ctxt tree2 js1
+        val thm1' = implies_intr ct1 thm1
+        val thm2' = implies_intr ct2 thm2
+    in (thm2' COMP (thm1' COMP thm), js2) end;
+(* needs handle THM _ => NONE ? *)
+
 fun prover (ctxt as (sg, ss)) thms (Tconcl : term) (js : injust list) (pos : bool) : thm option =
 let
+(* vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv *)
+(* Use this code instead if lin_arith_prover calls prove with do_pre set to true *)
+(* but beware: this can be a significant performance issue.                      *)
     (* There is no "forward version" of 'pre_tac'.  Therefore we combine the     *)
     (* available theorems into a single proof state and perform "backward proof" *)
     (* using 'refute_tac'.                                                       *)
+(*
     val Hs    = map prop_of thms
     val Prop  = fold (curry Logic.mk_implies) (rev Hs) Tconcl
     val cProp = cterm_of sg Prop
@@ -802,6 +850,15 @@
                   |> Seq.hd
                   |> Goal.finish
                   |> fold (fn thA => fn thAB => implies_elim thAB thA) thms
+*)
+(* ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ *)
+    val nTconcl       = LA_Logic.neg_prop Tconcl
+    val cnTconcl      = cterm_of sg nTconcl
+    val nTconclthm    = assume cnTconcl
+    val tree          = splitasms sg (thms @ [nTconclthm])
+    val (Falsethm, _) = fwdproof ctxt tree js
+    val contr         = if pos then LA_Logic.ccontr else LA_Logic.notI
+    val concl         = implies_intr cnTconcl Falsethm COMP contr
 in SOME (trace_thm "Proved by lin. arith. prover:"
           (LA_Logic.mk_Eq concl)) end
 (* in case concl contains ?-var, which makes assume fail: *)
@@ -823,10 +880,10 @@
     val _ = map (trace_thm "thms:") thms
     val _ = trace_msg ("concl:" ^ Sign.string_of_term sg concl)
 *)
-in case prove sg [] false Hs Tconcl of (* concl provable? *)
+in case prove sg [] false false Hs Tconcl of (* concl provable? *)
      SOME js => prover (sg, ss) thms Tconcl js true
    | NONE => let val nTconcl = LA_Logic.neg_prop Tconcl
-          in case prove sg [] false Hs nTconcl of (* ~concl provable? *)
+          in case prove sg [] false false Hs nTconcl of (* ~concl provable? *)
                SOME js => prover (sg, ss) thms nTconcl js false
              | NONE => NONE
           end