Added counter example generation.
authornipkow
Tue, 13 Aug 2002 21:55:58 +0200
changeset 13498 5330f1744817
parent 13497 defb74f6a5bc
child 13499 f95f5818f24f
Added counter example generation.
src/Provers/Arith/fast_lin_arith.ML
--- a/src/Provers/Arith/fast_lin_arith.ML	Tue Aug 13 21:54:23 2002 +0200
+++ b/src/Provers/Arith/fast_lin_arith.ML	Tue Aug 13 21:55:58 2002 +0200
@@ -78,7 +78,7 @@
                 -> theory -> theory
   val trace           : bool ref
   val lin_arith_prover: Sign.sg -> thm list -> term -> thm option
-  val     lin_arith_tac:             int -> tactic
+  val     lin_arith_tac:     bool -> int -> tactic
   val cut_lin_arith_tac: thm list -> int -> tactic
 end;
 
@@ -145,6 +145,102 @@
 
 datatype lineq = Lineq of int * lineq_type * int list * injust;
 
+fun el 0 (h::_) = h
+  | el n (_::t) = el (n - 1) t
+  | el _ _  = sys_error "el";
+
+(* ------------------------------------------------------------------------- *)
+(* Finding a (counter) example from the trace of a failed elimination        *)
+(* ------------------------------------------------------------------------- *)
+(* Examples are represented as rational numbers,                             *)
+(* although at the moment all examples are rounded to integers -             *)
+(* thus it does not yet work for type real.                                  *)
+(* Dont blame John Harrison for this code - it is entirely mine. TN          *)
+
+exception NoEx;
+exception NotYetImpl;
+
+fun elim_eqns(ineqs,Lineq(i,Le,cs,_)) = (i,cs)::ineqs
+  | elim_eqns(ineqs,Lineq(i,Eq,cs,_)) = (i,cs)::(~i,map ~ cs)::ineqs
+  | elim_eqns(ineqs,Lineq(i,Lt,cs,_)) = raise NotYetImpl;
+
+val rat0 = rat_of_int 0;
+
+(* PRE: ex[v] must be 0! *)
+fun eval (ex:rat list) v (a:int,cs:int list) =
+  let val rs = map rat_of_int cs
+      val rsum = foldl ratadd (rat0,map ratmul (rs ~~ ex))
+  in ratmul(ratadd(rat_of_int a,ratneg rsum), ratinv(el v rs)) end;
+
+(*
+fun ratge0(Rat(a,p,q)) = (p = 0 orelse a)
+*)
+fun ratge0 r = fst(rep_rat r) >= 0;
+fun ratle(r,s) = ratge0(ratadd(s,ratneg r))
+
+fun ratmin2(r,s) = if ratle(r,s) then r else s;
+fun ratmax2(r,s) = if ratle(r,s) then s else r;
+
+val ratmin = foldr1 ratmax2;
+val ratmax = foldr1 ratmax2;
+
+fun ratroundup r = let val (p,q) = rep_rat r
+                   in if q=1 then r else rat_of_int((p div q) + 1) end
+
+fun ratrounddown r = let val (p,q) = rep_rat r
+                     in if q=1 then r else rat_of_int((p div q) - 1) end
+
+fun choose2 d (lb,ub) =
+  if ratle(lb,rat0) andalso ratle(rat0,ub) then rat0 else
+  if not d then (if ratge0 lb then lb else ub) else
+  if ratge0 lb then let val lb' = ratroundup lb
+                    in if ratle(lb',ub) then lb' else raise NoEx end
+               else let val ub' = ratrounddown ub
+                    in if ratle(lb,ub') then ub' else raise NoEx end;
+
+fun findex1 discr (ex,(v,lineqs)) =
+  let val nz = filter (fn (Lineq(_,_,cs,_)) => el v cs <> 0) lineqs;
+      val ineqs = foldl elim_eqns ([],nz)
+      val (ge,le) = partition (fn (_,cs) => el v cs > 0) ineqs
+      val lb = ratmax(map (eval ex v) ge)
+      val ub = ratmin(map (eval ex v) le)
+  in nth_update (choose2 (nth_elem(v,discr)) (lb,ub)) (v,ex) end;
+
+fun findex discr = foldl (findex1 discr);
+
+fun elim1 v x =
+  map (fn (a,bs) => (ratadd(a,ratneg(ratmul(el v bs,x))),
+                     nth_update rat0 (v,bs)));
+
+fun single_var v cs = (filter_out (equal rat0) cs = [el v cs]);
+
+(* The base case:
+   all variables occur only with positive or only with negative coefficients *)
+fun pick_vars discr (ineqs,ex) =
+  let val nz = filter_out (forall (equal rat0) o snd) ineqs
+  in if null nz then ex
+     else let val v = find_index (not o equal rat0) (snd(hd nz))
+              val d = nth_elem(v,discr)
+              val sv = filter (single_var v o snd) nz
+              val minmax = if ratge0(el v (snd(hd nz)))
+                           then if d then ratroundup o ratmax else ratmax
+                           else if d then ratrounddown o ratmin else ratmin
+              val bnds = map (fn (a,bs) => ratmul(a,ratinv(el v bs))) sv
+              val x = minmax(rat0::bnds)
+              val ineqs' = elim1 v x nz
+              val ex' = nth_update x (v,ex)
+          in pick_vars discr (ineqs',ex') end
+  end;
+
+fun findex0 discr n lineqs =
+  let val ineqs = foldl elim_eqns ([],lineqs)
+      val rineqs = map (fn (a,cs) => (rat_of_int a, map rat_of_int cs)) ineqs
+  in pick_vars discr (rineqs,replicate n rat0) end;
+
+(* ------------------------------------------------------------------------- *)
+(* End of counter example finder. The actual decision procedure starts here. *)
+(* ------------------------------------------------------------------------- *)
+
 (* ------------------------------------------------------------------------- *)
 (* Calculate new (in)equality type after addition.                           *)
 (* ------------------------------------------------------------------------- *)
@@ -178,10 +274,6 @@
 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
 (* ------------------------------------------------------------------------- *)
 
-fun el 0 (h::_) = h
-  | el n (_::t) = el (n - 1) t
-  | el _ _  = sys_error "el";
-
 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   let val c1 = el v l1 and c2 = el v l2
       val m = lcm(abs c1,abs c2)
@@ -238,14 +330,19 @@
        commas(map string_of_int l)) ineqs))
   else ();
 
-fun elim ineqs =
+type history = (int * lineq list) list;
+datatype result = Success of injust | Failure of history;
+
+fun elim(ineqs,hist) =
   let val dummy = print_ineqs ineqs;
       val (triv,nontriv) = partition is_trivial ineqs in
   if not(null triv)
   then case Library.find_first is_answer triv of
-         None => elim nontriv | some => some
+         None => elim(nontriv,hist)
+       | Some(Lineq(_,_,_,j)) => Success j
   else
-  if null nontriv then None else
+  if null nontriv then Failure(hist)
+  else
   let val (eqs,noneqs) = partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   if not(null eqs) then
      let val clist = foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
@@ -254,11 +351,11 @@
          val c = hd sclist
          val (Some(eq as Lineq(_,_,ceq,_)),othereqs) =
                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
-         val v = find_index (fn k => k=c) ceq
+         val v = find_index_eq c ceq
          val (ioth,roth) = partition (fn (Lineq(_,_,l,_)) => el v l = 0)
                                      (othereqs @ noneqs)
          val others = map (elim_var v eq) roth @ ioth
-     in elim others end
+     in elim(others,(v,nontriv)::hist) end
   else
   let val lists = map (fn (Lineq(_,_,l,_)) => l) noneqs
       val numlist = 0 upto (length(hd lists) - 1)
@@ -266,11 +363,12 @@
       val blows = map calc_blowup coeffs
       val iblows = blows ~~ numlist
       val nziblows = filter (fn (i,_) => i<>0) iblows
-  in if null nziblows then None else
+  in if null nziblows then Failure((~1,nontriv)::hist)
+     else
      let val (c,v) = hd(sort (fn (x,y) => int_ord(fst(x),fst(y))) nziblows)
          val (no,yes) = partition (fn (Lineq(_,_,l,_)) => el v l = 0) ineqs
          val (pos,neg) = partition(fn (Lineq(_,_,l,_)) => el v l > 0) yes
-     in elim (no @ allpairs (elim_var v) pos neg) end
+     in elim(no @ allpairs (elim_var v) pos neg, (v,nontriv)::hist) end
   end
   end
   end;
@@ -285,7 +383,7 @@
 fun trace_msg msg = 
     if !trace then tracing msg else ();
 
-(* FIXME OPTIMIZE!!!!
+(* FIXME OPTIMIZE!!!! (partly done already)
    Addition/Multiplication need i*t representation rather than t+t+...
    Get rid of Mulitplied(2). For Nat LA_Data.number_of should return Suc^n
    because Numerals are not known early enough.
@@ -371,30 +469,63 @@
     fun mult(t,r) = let val (i,j) = rep_rat r in (t,i * (m div j)) end
 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
 
-fun mklineq atoms =
-  let val n = length atoms in
-    fn (item,k) =>
-    let val (m,(lhs,i,rel,rhs,j,discrete)) = integ item
-        val lhsa = map (coeff lhs) atoms
-        and rhsa = map (coeff rhs) atoms
-        val diff = map2 (op -) (rhsa,lhsa)
-        val c = i-j
-        val just = Asm k
-        fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied2(m,j))
-    in case rel of
-        "<="   => lineq(c,Le,diff,just)
-       | "~<=" => if discrete
-                  then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
-                  else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
-       | "<"   => if discrete
-                  then lineq(c+1,Le,diff,LessD(just))
-                  else lineq(c,Lt,diff,just)
-       | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
-       | "="   => lineq(c,Eq,diff,just)
-       | _     => sys_error("mklineq" ^ rel)   
-    end
+fun mklineq n atoms =
+  fn (item,k) =>
+  let val (m,(lhs,i,rel,rhs,j,discrete)) = integ item
+      val lhsa = map (coeff lhs) atoms
+      and rhsa = map (coeff rhs) atoms
+      val diff = map2 (op -) (rhsa,lhsa)
+      val c = i-j
+      val just = Asm k
+      fun lineq(c,le,cs,j) = Lineq(c,le,cs, if m=1 then j else Multiplied2(m,j))
+  in case rel of
+      "<="   => lineq(c,Le,diff,just)
+     | "~<=" => if discrete
+                then lineq(1-c,Le,map (op ~) diff,NotLeDD(just))
+                else lineq(~c,Lt,map (op ~) diff,NotLeD(just))
+     | "<"   => if discrete
+                then lineq(c+1,Le,diff,LessD(just))
+                else lineq(c,Lt,diff,just)
+     | "~<"  => lineq(~c,Le,map (op~) diff,NotLessD(just))
+     | "="   => lineq(c,Eq,diff,just)
+     | _     => sys_error("mklineq" ^ rel)   
   end;
 
+(* ------------------------------------------------------------------------- *)
+(* Print (counter) example                                                   *)
+(* ------------------------------------------------------------------------- *)
+
+fun print_atom((a,d),r) =
+  let val (p,q) = rep_rat r
+      val s = if d then string_of_int p else
+              if p = 0 then "0"
+              else string_of_int p ^ "/" ^ string_of_int q
+  in a ^ " = " ^ s end;
+
+fun print_ex sds =
+  tracing o
+  apl("Counter example:\n",op ^) o
+  commas o
+  map print_atom o
+  apl(sds, op ~~);
+
+fun trace_ex(sg,params,atoms,discr,n,hist:history) =
+  if null hist then ()
+  else let val frees = map Free params;
+           fun s_of_t t = Sign.string_of_term sg (subst_bounds(frees,t));
+           val (v,lineqs) :: hist' = hist
+           val start = if v = ~1 then (findex0 discr n lineqs,hist')
+                       else (replicate n rat0,hist)
+       in print_ex ((map s_of_t atoms)~~discr) (findex discr start)
+          handle NoEx =>
+  (tracing "The decision procedure failed to prove your proposition\n\
+           \but could not construct a counter example either.\n\
+           \Probably the proposition is true but cannot be proved\n\
+           \by the incomplete decision procedure.")
+       end
+       handle NotYetImpl =>
+ tracing "No counter example: < on real not yet implemented.";
+
 fun mknat pTs ixs (atom,i) =
   if LA_Logic.is_nat(pTs,atom)
   then let val l = map (fn j => if j=i then 1 else 0) ixs
@@ -405,7 +536,7 @@
 in the subgoal. Numerical premises are coded as Some(tuple), non-numerical
 ones as None. Going through the premises, each numeric one is converted into
 a Lineq. The tricky bit is to convert ~= which is split into two cases < and
->. Thus mklineqss returns a list of equation systems. This may blow up if
+>. Thus split_items returns a list of equation systems. This may blow up if
 there are many ~=, but in practice it does not seem to happen. The really
 tricky bit is to arrange the order of the cases such that they coincide with
 the order in which the cases are in the end generated by the tactic that
@@ -422,49 +553,36 @@
           else elim_neq ((ineq,n) :: front) (n+1) ineqs
   in elim_neq [] 0 items end;
 
-fun mklineqss(pTs,items) =
-let
-  fun mklineqs(ineqs) =
-  let
-    fun add(ats,((lhs,_,_,rhs,_,_),_)) =
-             (map fst lhs) union ((map fst rhs) union ats)
-    val atoms = foldl add ([],ineqs)
-    val mkleq = mklineq atoms
-    val ixs = 0 upto (length(atoms)-1)
-    val iatoms = atoms ~~ ixs
-    val natlineqs = mapfilter (mknat pTs ixs) iatoms
-  in map mkleq ineqs @ natlineqs end
+fun add_atoms(ats,((lhs,_,_,rhs,_,_),_)) =
+    (map fst lhs) union ((map fst rhs) union ats)
 
-in map mklineqs (split_items items) end;
+fun add_datoms(dats,((lhs,_,_,rhs,_,d),_)) =
+    (map (pair d o fst) lhs) union ((map (pair d o fst) rhs) union dats)
+
+fun discr initems = map fst (foldl add_datoms ([],initems));
 
-(*
-fun mklineqss(pTs,items) =
-  let fun add(ats,None) = ats
-        | add(ats,Some(lhs,_,_,rhs,_,_)) =
-             (map fst lhs) union ((map fst rhs) union ats)
-      val atoms = foldl add ([],items)
-      val mkleq = mklineq atoms
-      val ixs = 0 upto (length(atoms)-1)
-      val iatoms = atoms ~~ ixs
-      val natlineqs = mapfilter (mknat pTs ixs) iatoms
- 
-      fun elim_neq front _ [] = [front]
-        | elim_neq front n (None::ineqs) = elim_neq front (n+1) ineqs
-        | elim_neq front n (Some(ineq as (l,i,rel,r,j,d))::ineqs) =
-          if rel = "~=" then elim_neq front n (ineqs @ [Some(l,i,"<",r,j,d)]) @
-                             elim_neq front n (ineqs @ [Some(r,j,"<",l,i,d)])
-          else elim_neq (mkleq(ineq,n) :: front) (n+1) ineqs
+fun refutes sg (pTs,params) ex =
+let
+  fun refute (initems::initemss) js =
+    let val atoms = foldl add_atoms ([],initems)
+        val n = length atoms
+        val mkleq = mklineq n atoms
+        val ixs = 0 upto (n-1)
+        val iatoms = atoms ~~ ixs
+        val natlineqs = mapfilter (mknat pTs ixs) iatoms
+        val ineqs = map mkleq initems @ natlineqs
+    in case elim(ineqs,[]) of
+         Success(j) =>
+           (trace_msg "Contradiction!"; refute initemss (js@[j]))
+       | Failure(hist) =>
+           (if not ex then ()
+            else trace_ex(sg,params,atoms,discr initems,n,hist);
+            None)
+    end
+    | refute [] js = Some js
+in refute end;
 
-  in elim_neq natlineqs 0 items end;
-*)
-
-fun elim_all (ineqs::ineqss) js =
-  (case elim ineqs of None => (trace_msg "No contradiction!"; None)
-   | Some(Lineq(_,_,_,j)) => (trace_msg "Contradiction!";
-                              elim_all ineqss (js@[j])))
-  | elim_all [] js = Some js
-
-fun refute(pTsitems) = elim_all (mklineqss pTsitems) [];
+fun refute sg ps ex items = refutes sg ps ex (split_items items) [];
 
 fun refute_tac(i,justs) =
   fn state =>
@@ -476,37 +594,38 @@
     end
     state;
 
-fun prove sg (pTs,Hs,concl) =
+fun prove sg ps ex Hs concl =
 let val Hitems = map (LA_Data.decomp sg) Hs
 in case LA_Data.decomp sg concl of
-     None => refute(pTs,Hitems@[None])
+     None => refute sg ps ex (Hitems@[None])
    | Some(citem as (r,i,rel,l,j,d)) =>
        let val neg::rel0 = explode rel
            val nrel = if neg = "~" then implode rel0 else "~"^rel
-       in refute(pTs, Hitems @ [Some(r,i,nrel,l,j,d)]) end
+       in refute sg ps ex (Hitems @ [Some(r,i,nrel,l,j,d)]) end
 end;
 
 (*
 Fast but very incomplete decider. Only premises and conclusions
 that are already (negated) (in)equations are taken into account.
 *)
-fun lin_arith_tac i st = SUBGOAL (fn (A,_) =>
-  let val pTs = rev(map snd (Logic.strip_params A))
+fun lin_arith_tac ex i st = SUBGOAL (fn (A,_) =>
+  let val params = rev(Logic.strip_params A)
+      val pTs = map snd params
       val Hs = Logic.strip_assums_hyp A
       val concl = Logic.strip_assums_concl A
   in trace_thm ("Trying to refute subgoal " ^ string_of_int i) st;
-     case prove (Thm.sign_of_thm st) (pTs,Hs,concl) of
+     case prove (Thm.sign_of_thm st) (pTs,params) ex Hs concl of
        None => (trace_msg "Refutation failed."; no_tac)
      | Some js => (trace_msg "Refutation succeeded."; refute_tac(i,js))
   end) i st;
 
-fun cut_lin_arith_tac thms i = cut_facts_tac thms i THEN lin_arith_tac i;
+fun cut_lin_arith_tac thms i = cut_facts_tac thms i THEN lin_arith_tac false i;
 
 (** Forward proof from theorems **)
 
 (* More tricky code. Needs to arrange the proofs of the multiple cases (due
 to splits of ~= premises) such that it coincides with the order of the cases
-generated by function mklineqss. *)
+generated by function split_items. *)
 
 datatype splittree = Tip of thm list
                    | Spl of thm * cterm * splittree * cterm * splittree
@@ -561,10 +680,10 @@
 fun lin_arith_prover sg thms concl =
 let val Hs = map (#prop o rep_thm) thms
     val Tconcl = LA_Logic.mk_Trueprop concl
-in case prove sg ([],Hs,Tconcl) of (* concl provable? *)
+in case prove sg ([],[]) false Hs Tconcl of (* concl provable? *)
      Some js => prover sg thms Tconcl js true
    | None => let val nTconcl = LA_Logic.neg_prop Tconcl
-          in case prove sg ([],Hs,nTconcl) of (* ~concl provable? *)
+          in case prove sg ([],[]) false Hs nTconcl of (* ~concl provable? *)
                Some js => prover sg thms nTconcl js false
              | None => None
           end