src/Provers/Arith/fast_lin_arith.ML
changeset 16358 2e2a506553a3
parent 15965 f422f8283491
child 16458 4c6fd0c01d28
equal deleted inserted replaced
16357:f1275d2a1dee 16358:2e2a506553a3
    13 
    13 
    14     lin_arith_prover: Sign.sg -> simpset -> term -> thm option
    14     lin_arith_prover: Sign.sg -> simpset -> term -> thm option
    15 
    15 
    16 Only take premises and conclusions into account that are already (negated)
    16 Only take premises and conclusions into account that are already (negated)
    17 (in)equations. lin_arith_prover tries to prove or disprove the term.
    17 (in)equations. lin_arith_prover tries to prove or disprove the term.
    18 
       
    19 FIXME: convert to IntInf.int throughout. 
       
    20 *)
    18 *)
    21 
    19 
    22 (* Debugging: set Fast_Arith.trace *)
    20 (* Debugging: set Fast_Arith.trace *)
    23 
    21 
    24 (*** Data needed for setting up the linear arithmetic package ***)
    22 (*** Data needed for setting up the linear arithmetic package ***)
    53 
    51 
    54 signature LIN_ARITH_DATA =
    52 signature LIN_ARITH_DATA =
    55 sig
    53 sig
    56   val decomp:
    54   val decomp:
    57     Sign.sg -> term -> ((term*rat)list * rat * string * (term*rat)list * rat * bool)option
    55     Sign.sg -> term -> ((term*rat)list * rat * string * (term*rat)list * rat * bool)option
    58   val number_of: int * typ -> term
    56   val number_of: IntInf.int * typ -> term
    59 end;
    57 end;
    60 (*
    58 (*
    61 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
    59 decomp(`x Rel y') should yield (p,i,Rel,q,j,d)
    62    where Rel is one of "<", "~<", "<=", "~<=" and "=" and
    60    where Rel is one of "<", "~<", "<=", "~<=" and "=" and
    63          p/q is the decomposition of the sum terms x/y into a list
    61          p/q is the decomposition of the sum terms x/y into a list
   139                 | Nat of int (* index of atom *)
   137                 | Nat of int (* index of atom *)
   140                 | LessD of injust
   138                 | LessD of injust
   141                 | NotLessD of injust
   139                 | NotLessD of injust
   142                 | NotLeD of injust
   140                 | NotLeD of injust
   143                 | NotLeDD of injust
   141                 | NotLeDD of injust
   144                 | Multiplied of int * injust
   142                 | Multiplied of IntInf.int * injust
   145                 | Multiplied2 of int * injust
   143                 | Multiplied2 of IntInf.int * injust
   146                 | Added of injust * injust;
   144                 | Added of injust * injust;
   147 
   145 
   148 datatype lineq = Lineq of int * lineq_type * int list * injust;
   146 datatype lineq = Lineq of IntInf.int * lineq_type * IntInf.int list * injust;
   149 
   147 
   150 fun el 0 (h::_) = h
   148 fun el 0 (h::_) = h
   151   | el n (_::t) = el (n - 1) t
   149   | el n (_::t) = el (n - 1) t
   152   | el _ _  = sys_error "el";
   150   | el _ _  = sys_error "el";
   153 
   151 
   169   | elim_eqns(ineqs,Lineq(i,Lt,cs,_)) = (i,false,cs)::ineqs;
   167   | elim_eqns(ineqs,Lineq(i,Lt,cs,_)) = (i,false,cs)::ineqs;
   170 
   168 
   171 val rat0 = rat_of_int 0;
   169 val rat0 = rat_of_int 0;
   172 
   170 
   173 (* PRE: ex[v] must be 0! *)
   171 (* PRE: ex[v] must be 0! *)
   174 fun eval (ex:rat list) v (a:int,le,cs:int list) =
   172 fun eval (ex:rat list) v (a:IntInf.int,le,cs:IntInf.int list) =
   175   let val rs = map rat_of_int cs
   173   let val rs = map rat_of_intinf cs
   176       val rsum = Library.foldl ratadd (rat0,map ratmul (rs ~~ ex))
   174       val rsum = Library.foldl ratadd (rat0,map ratmul (rs ~~ ex))
   177   in (ratmul(ratadd(rat_of_int a,ratneg rsum), ratinv(el v rs)), le) end;
   175   in (ratmul(ratadd(rat_of_intinf a,ratneg rsum), ratinv(el v rs)), le) end;
   178 (* If el v rs < 0, le should be negated.
   176 (* If el v rs < 0, le should be negated.
   179    Instead this swap is taken into account in ratrelmin2.
   177    Instead this swap is taken into account in ratrelmin2.
   180 *)
   178 *)
   181 
   179 
   182 fun ratge0 r = fst(rep_rat r) >= 0;
   180 fun ratge0 r = fst(rep_rat r) >= 0;
   256        in pick_vars discr (ineqs',ex') end
   254        in pick_vars discr (ineqs',ex') end
   257   end;
   255   end;
   258 
   256 
   259 fun findex0 discr n lineqs =
   257 fun findex0 discr n lineqs =
   260   let val ineqs = Library.foldl elim_eqns ([],lineqs)
   258   let val ineqs = Library.foldl elim_eqns ([],lineqs)
   261       val rineqs = map (fn (a,le,cs) => (rat_of_int a, le, map rat_of_int cs))
   259       val rineqs = map (fn (a,le,cs) => (rat_of_intinf a, le, map rat_of_intinf cs))
   262                        ineqs
   260                        ineqs
   263   in pick_vars discr (rineqs,replicate n rat0) end;
   261   in pick_vars discr (rineqs,replicate n rat0) end;
   264 
   262 
   265 (* ------------------------------------------------------------------------- *)
   263 (* ------------------------------------------------------------------------- *)
   266 (* End of counter example finder. The actual decision procedure starts here. *)
   264 (* End of counter example finder. The actual decision procedure starts here. *)
   299 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   297 (* If they're both inequalities, 1st coefficient must be +ve, 2nd -ve.       *)
   300 (* ------------------------------------------------------------------------- *)
   298 (* ------------------------------------------------------------------------- *)
   301 
   299 
   302 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   300 fun elim_var v (i1 as Lineq(k1,ty1,l1,just1)) (i2 as Lineq(k2,ty2,l2,just2)) =
   303   let val c1 = el v l1 and c2 = el v l2
   301   let val c1 = el v l1 and c2 = el v l2
   304       val m = IntInf.toInt (lcm(IntInf.fromInt (abs c1), IntInf.fromInt(abs c2)))
   302       val m = lcm(abs c1, abs c2)
   305       val m1 = m div (abs c1) and m2 = m div (abs c2)
   303       val m1 = m div (abs c1) and m2 = m div (abs c2)
   306       val (n1,n2) =
   304       val (n1,n2) =
   307         if (c1 >= 0) = (c2 >= 0)
   305         if (c1 >= 0) = (c2 >= 0)
   308         then if ty1 = Eq then (~m1,m2)
   306         then if ty1 = Eq then (~m1,m2)
   309              else if ty2 = Eq then (m1,~m2)
   307              else if ty2 = Eq then (m1,~m2)
   320 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   318 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   321 
   319 
   322 fun is_answer (ans as Lineq(k,ty,l,_)) =
   320 fun is_answer (ans as Lineq(k,ty,l,_)) =
   323   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   321   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   324 
   322 
   325 fun calc_blowup l =
   323 fun calc_blowup (l:IntInf.int list) =
   326   let val (p,n) = List.partition (apl(0,op<)) (List.filter (apl(0,op<>)) l)
   324   let val (p,n) = List.partition (apl(0,op<)) (List.filter (apl(0,op<>)) l)
   327   in (length p) * (length n) end;
   325   in (length p) * (length n) end;
   328 
   326 
   329 (* ------------------------------------------------------------------------- *)
   327 (* ------------------------------------------------------------------------- *)
   330 (* Main elimination code:                                                    *)
   328 (* Main elimination code:                                                    *)
   348   in extract [] end;
   346   in extract [] end;
   349 
   347 
   350 fun print_ineqs ineqs =
   348 fun print_ineqs ineqs =
   351   if !trace then
   349   if !trace then
   352      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   350      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   353        string_of_int c ^
   351        IntInf.toString c ^
   354        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   352        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   355        commas(map string_of_int l)) ineqs))
   353        commas(map IntInf.toString l)) ineqs))
   356   else ();
   354   else ();
   357 
   355 
   358 type history = (int * lineq list) list;
   356 type history = (int * lineq list) list;
   359 datatype result = Success of injust | Failure of history;
   357 datatype result = Success of injust | Failure of history;
   360 
   358 
   369   if null nontriv then Failure(hist)
   367   if null nontriv then Failure(hist)
   370   else
   368   else
   371   let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   369   let val (eqs,noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   372   if not(null eqs) then
   370   if not(null eqs) then
   373      let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
   371      let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
   374          val sclist = sort (fn (x,y) => int_ord(abs(x),abs(y)))
   372          val sclist = sort (fn (x,y) => IntInf.compare(abs(x),abs(y)))
   375                            (List.filter (fn i => i<>0) clist)
   373                            (List.filter (fn i => i<>0) clist)
   376          val c = hd sclist
   374          val c = hd sclist
   377          val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
   375          val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
   378                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
   376                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
   379          val v = find_index_eq c ceq
   377          val v = find_index_eq c ceq
   476         | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD))
   474         | mk(LessD(j)) = trace_thm "L" (hd([mk j] RL lessD))
   477         | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
   475         | mk(NotLeD(j)) = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
   478         | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD))
   476         | mk(NotLeDD(j)) = trace_thm "NLeD" (hd([mk j RS LA_Logic.not_leD] RL lessD))
   479         | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
   477         | mk(NotLessD(j)) = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
   480         | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
   478         | mk(Added(j1,j2)) = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
   481         | mk(Multiplied(n,j)) = (trace_msg("*"^string_of_int n); trace_thm "*" (multn(n,mk j)))
   479         | mk(Multiplied(n,j)) = (trace_msg("*"^IntInf.toString n); trace_thm "*" (multn(n,mk j)))
   482         | mk(Multiplied2(n,j)) = simp (trace_msg("**"^string_of_int n); trace_thm "**" (multn2(n,mk j)))
   480         | mk(Multiplied2(n,j)) = simp (trace_msg("**"^IntInf.toString n); trace_thm "**" (multn2(n,mk j)))
   483 
   481 
   484   in trace_msg "mkthm";
   482   in trace_msg "mkthm";
   485      let val thm = trace_thm "Final thm:" (mk just)
   483      let val thm = trace_thm "Final thm:" (mk just)
   486      in let val fls = simplify simpset thm
   484      in let val fls = simplify simpset thm
   487         in trace_thm "After simplification:" fls;
   485         in trace_thm "After simplification:" fls;
   495         end
   493         end
   496      end handle FalseE thm => (trace_thm "False reached early:" thm; thm)
   494      end handle FalseE thm => (trace_thm "False reached early:" thm; thm)
   497   end
   495   end
   498 end;
   496 end;
   499 
   497 
   500 fun coeff poly atom = case assoc(poly,atom) of NONE => 0 | SOME i => i;
   498 fun coeff poly atom : IntInf.int =
   501 
   499   case assoc(poly,atom) of NONE => 0 | SOME i => i;
   502 fun lcms_intinf is = Library.foldl lcm (1, is);
   500 
   503 fun lcms is = IntInf.toInt (lcms_intinf (map IntInf.fromInt is));
   501 fun lcms is = Library.foldl lcm (1, is);
   504 
   502 
   505 fun integ(rlhs,r,rel,rrhs,s,d) =
   503 fun integ(rlhs,r,rel,rrhs,s,d) =
   506 let val (rn,rd) = pairself IntInf.toInt (rep_rat r) and (sn,sd) = pairself IntInf.toInt (rep_rat s)
   504 let val (rn,rd) = rep_rat r and (sn,sd) = rep_rat s
   507     val m = IntInf.toInt (lcms_intinf(map (abs o snd o rep_rat) (r :: s :: map snd rlhs @ map snd rrhs)))
   505     val m = lcms(map (abs o snd o rep_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   508     fun mult(t,r) = 
   506     fun mult(t,r) = 
   509         let val (i,j) =  pairself IntInf.toInt (rep_rat r) 
   507         let val (i,j) =  (rep_rat r) 
   510         in (t,i * (m div j)) end
   508         in (t,i * (m div j)) end
   511 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   509 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   512 
   510 
   513 fun mklineq n atoms =
   511 fun mklineq n atoms =
   514   fn (item,k) =>
   512   fn (item,k) =>