Version 1 of linear arithmetic for nat.
authornipkow
Mon, 04 Jan 1999 15:08:40 +0100
changeset 6056 b21813d1b701
parent 6055 fdf4638bf726
child 6057 395ea7617554
Version 1 of linear arithmetic for nat.
src/Provers/Arith/fast_lin_arith.ML
--- a/src/Provers/Arith/fast_lin_arith.ML	Mon Jan 04 15:07:47 1999 +0100
+++ b/src/Provers/Arith/fast_lin_arith.ML	Mon Jan 04 15:08:40 1999 +0100
@@ -28,6 +28,9 @@
   val decomp: term ->
              ((term * int)list * int * string * (term * int)list * int)option
   val simp: thm -> thm
+  val is_False: thm -> bool
+  val is_nat: typ list * term -> bool
+  val mk_nat_thm: Sign.sg -> term -> thm
 end;
 (*
 decomp(`x Rel y') should yield (p,i,Rel,q,j)
@@ -38,6 +41,9 @@
 simp must reduce contradictory <= to False.
    It should also cancel common summands to keep <= reduced;
    otherwise <= can grow to massive proportions.
+
+is_nat(parameter-types,t) =  t:nat
+mk_nat_thm(t) = "0 <= t"
 *)
 
 functor Fast_Lin_Arith(LA_Data:LIN_ARITH_DATA) =
@@ -45,11 +51,15 @@
 
 (*** A fast decision procedure ***)
 (*** Code ported from HOL Light ***)
-(* possible optimizations: eliminate eqns first; use (var,coeff) rep *)
+(* possible optimizations:
+   use (var,coeff) rep or vector rep  tp save space;
+   treat non-negative atoms separately rather than adding 0 <= atom
+*)
 
 datatype lineq_type = Eq | Le | Lt;
 
-datatype injust = Given of int
+datatype injust = Asm of int
+                | Nat of int (* index of atom *)
                 | Fwd of injust * thm
                 | Multiplied of int * injust
                 | Added of injust * injust;
@@ -148,8 +158,17 @@
         | extract xs []      = (None,xs)
   in extract [] end;
 
+(*
+fun print_ineqs ineqs =
+ writeln(cat_lines(""::map (fn Lineq(c,t,l,_) =>
+   string_of_int c ^
+   (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
+   commas(map string_of_int l)) ineqs));
+*)
+
 fun elim ineqs =
-  let val (triv,nontriv) = partition is_trivial ineqs in
+  let (*val dummy = print_ineqs ineqs;*)
+      val (triv,nontriv) = partition is_trivial ineqs in
   if not(null triv)
   then case find_first is_answer triv of
          None => elim nontriv | some => some
@@ -188,9 +207,22 @@
 (* Translate back a proof.                                                   *)
 (* ------------------------------------------------------------------------- *)
 
-(* FIXME OPTIMIZE!!!! *)
-fun mkproof asms just =
-  let fun addthms thm1 thm2 =
+(* FIXME OPTIMIZE!!!!
+   Addition/Multiplication need i*t representation rather than t+t+...
+
+Simplification may detect a contradiction 'prematurely' due to type
+information: n+1 <= 0 is simplified to False and does not need to be crossed
+with 0 <= n.
+*)
+local
+ exception FalseE of thm
+in
+fun mkproof sg asms just =
+  let val atoms = foldl (fn (ats,(lhs,_,_,rhs,_)) =>
+                            map fst lhs  union  (map fst rhs  union  ats))
+                        ([], mapfilter (LA_Data.decomp o concl_of) asms)
+
+      fun addthms thm1 thm2 =
         let val conj = thm1 RS (thm2 RS LA_Data.conjI)
         in the(get_first (fn th => Some(conj RS th) handle _ => None)
                          LA_Data.add_mono_thms)
@@ -200,13 +232,18 @@
         let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
         in if n < 0 then mul(~n,thm) RS LA_Data.sym else mul(n,thm) end;
 
-      fun mk(Given i) = nth_elem(i,asms)
+      fun simp thm =
+        let val thm' = LA_Data.simp thm
+        in if LA_Data.is_False thm' then raise FalseE thm' else thm' end
+
+      fun mk(Asm i) = nth_elem(i,asms)
+        | mk(Nat(i)) = LA_Data.mk_nat_thm sg (nth_elem(i,atoms))
         | mk(Fwd(j,thm)) = mk j RS thm
-        | mk(Added(j1,j2)) = LA_Data.simp(addthms (mk j1) (mk j2))
+        | mk(Added(j1,j2)) = simp(addthms (mk j1) (mk j2))
         | mk(Multiplied(n,j)) = multn(n,mk j)
 
-  in LA_Data.simp(mk just) end;
-
+  in LA_Data.simp(mk just) handle FalseE thm => thm end
+end;
 
 fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i;
 
@@ -217,7 +254,7 @@
         and rhsa = map (coeff rhs) atoms
         val diff = map2 (op -) (rhsa,lhsa)
         val c = i-j
-        val just = Given k
+        val just = Asm k
     in case rel of
         "<="   => Some(Lineq(c,Le,diff,just))
        | "~<=" => Some(Lineq(1-c,Le,map (op ~) diff,Fwd(just,LA_Data.not_leD)))
@@ -229,52 +266,67 @@
     end
   end;
 
-fun abstract items =
+fun mknat pTs ixs (atom,i) =
+  if LA_Data.is_nat(pTs,atom)
+  then let val l = map (fn j => if j=i then 1 else 0) ixs
+       in Some(Lineq(0,Le,l,Nat(i))) end
+  else None
+
+fun abstract pTs items =
   let val atoms = foldl (fn (ats,((lhs,_,_,rhs,_),_)) =>
                             (map fst lhs) union ((map fst rhs) union ats))
                         ([],items)
-  in mapfilter (mklineq atoms) items end;
+      val ixs = 0 upto (length(atoms)-1)
+      val iatoms = atoms ~~ ixs
+  in mapfilter (mklineq atoms) items @ mapfilter (mknat pTs ixs) iatoms end;
 
 (* Ordinary refutation *)
-fun refute1_tac items =
-  let val lineqs = abstract items
+fun refute1_tac pTs items =
+  let val lineqs = abstract pTs items
   in case elim lineqs of
        None => K no_tac
      | Some(Lineq(_,_,_,j)) =>
-         resolve_tac [LA_Data.notI,LA_Data.ccontr] THEN'
-         METAHYPS (fn asms => rtac (mkproof asms j) 1)
+         fn i => fn state => 
+         let val sg = #sign(rep_thm state)
+         in resolve_tac [LA_Data.notI,LA_Data.ccontr] i THEN
+            METAHYPS (fn asms => rtac (mkproof sg asms j) 1) i
+         end state
   end;
 
 (* Double refutation caused by equality in conclusion *)
-fun refute2_tac items (rhs,i,_,lhs,j) nHs =
-  (case elim (abstract(items@[((rhs,i,"<",lhs,j),nHs)])) of
+fun refute2_tac pTs items (rhs,i,_,lhs,j) nHs =
+  (case elim (abstract pTs (items@[((rhs,i,"<",lhs,j),nHs)])) of
     None => K no_tac
   | Some(Lineq(_,_,_,j1)) =>
-      (case elim (abstract(items@[((lhs,j,"<",rhs,i),nHs)])) of
+      (case elim (abstract pTs (items@[((lhs,j,"<",rhs,i),nHs)])) of
         None => K no_tac
       | Some(Lineq(_,_,_,j2)) =>
-          rtac LA_Data.ccontr THEN' etac LA_Data.nat_neqE THEN'
-          METAHYPS (fn asms => rtac (mkproof asms j1) 1) THEN'
-          METAHYPS (fn asms => rtac (mkproof asms j2) 1) ));
+          fn i => fn state => 
+          let val sg = #sign(rep_thm state)
+          in rtac LA_Data.ccontr i THEN etac LA_Data.nat_neqE i THEN
+             METAHYPS (fn asms => rtac (mkproof sg asms j1) 1) i THEN
+             METAHYPS (fn asms => rtac (mkproof sg asms j2) 1) i
+          end state));
 
 (*
 Fast but very incomplete decider. Only premises and conclusions
 that are already (negated) (in)equations are taken into account.
 *)
 val lin_arith_tac = SUBGOAL (fn (A,n) =>
-  let val Hs = Logic.strip_assums_hyp A
+  let val pTs = rev(map snd (Logic.strip_params A))
+      val Hs = Logic.strip_assums_hyp A
       val nHs = length Hs
       val His = Hs ~~ (0 upto (nHs-1))
       val Hitems = mapfilter (fn (h,i) => case LA_Data.decomp h of
                                  None => None | Some(it) => Some(it,i)) His
   in case LA_Data.decomp(Logic.strip_assums_concl A) of
-       None => if null Hitems then no_tac else refute1_tac Hitems n
+       None => if null Hitems then no_tac else refute1_tac pTs Hitems n
      | Some(citem as (r,i,rel,l,j)) =>
          if rel = "="
-         then refute2_tac Hitems citem nHs n
+         then refute2_tac pTs Hitems citem nHs n
          else let val neg::rel0 = explode rel
                   val nrel = if neg = "~" then implode rel0 else "~"^rel
-              in refute1_tac (Hitems@[((r,i,nrel,l,j),nHs)]) n end
+              in refute1_tac pTs (Hitems@[((r,i,nrel,l,j),nHs)]) n end
   end);
 
 fun cut_lin_arith_tac thms i = cut_facts_tac thms i THEN lin_arith_tac i;