src/Provers/Arith/fast_lin_arith.ML
changeset 6128 2acc5d36610c
parent 6110 15c2b571225b
child 7551 8e934d1a9ac6
equal deleted inserted replaced
6127:ece970eb5850 6128:2acc5d36610c
    24   val conjI:		thm
    24   val conjI:		thm
    25   val ccontr:           thm (* (~ P ==> False) ==> P *)
    25   val ccontr:           thm (* (~ P ==> False) ==> P *)
    26   val neqE:             thm (* [| m ~= n; m < n ==> P; n < m ==> P |] ==> P *)
    26   val neqE:             thm (* [| m ~= n; m < n ==> P; n < m ==> P |] ==> P *)
    27   val notI:             thm (* (P ==> False) ==> ~ P *)
    27   val notI:             thm (* (P ==> False) ==> ~ P *)
    28   val not_lessD:        thm (* ~(m < n) ==> n <= m *)
    28   val not_lessD:        thm (* ~(m < n) ==> n <= m *)
       
    29   val not_leD:          thm (* ~(m <= n) ==> n < m *)
    29   val sym:		thm (* x = y ==> y = x *)
    30   val sym:		thm (* x = y ==> y = x *)
    30   val mk_Eq: thm -> thm
    31   val mk_Eq: thm -> thm
    31   val mk_Trueprop: term -> term
    32   val mk_Trueprop: term -> term
    32   val neg_prop: term -> term
    33   val neg_prop: term -> term
    33   val is_False: thm -> bool
    34   val is_False: thm -> bool
       
    35   val is_nat: typ list * term -> bool
       
    36   val mk_nat_thm: Sign.sg -> term -> thm
    34 end;
    37 end;
    35 (*
    38 (*
    36 mk_Eq(~in) = `in == False'
    39 mk_Eq(~in) = `in == False'
    37 mk_Eq(in) = `in == True'
    40 mk_Eq(in) = `in == True'
    38 where `in' is an (in)equality.
    41 where `in' is an (in)equality.
    39 
    42 
    40 neg_prop(t) = neg if t is wrapped up in Trueprop and
    43 neg_prop(t) = neg if t is wrapped up in Trueprop and
    41   nt is the (logically) negated version of t, where the negation
    44   nt is the (logically) negated version of t, where the negation
    42   of a negative term is the term itself (no double negation!);
    45   of a negative term is the term itself (no double negation!);
       
    46 
       
    47 is_nat(parameter-types,t) =  t:nat
       
    48 mk_nat_thm(t) = "0 <= t"
    43 *)
    49 *)
    44 
    50 
    45 signature LIN_ARITH_DATA =
    51 signature LIN_ARITH_DATA =
    46 sig
    52 sig
    47   val add_mono_thms:    thm list
    53   val add_mono_thms:    thm list ref
    48                             (* [| i rel1 j; m rel2 n |] ==> i + m rel3 j + n *)
    54                             (* [| i rel1 j; m rel2 n |] ==> i + m rel3 j + n *)
    49   val lessD:            thm (* m < n ==> Suc m <= n *)
    55   val lessD:            thm list ref (* m < n ==> m+1 <= n *)
    50   val not_leD:          thm (* ~(m <= n) ==> Suc n <= m *)
    56   val decomp:
    51   val decomp: term ->
    57     (term -> ((term * int)list * int * string * (term * int)list * int)option)
    52              ((term * int)list * int * string * (term * int)list * int)option
    58     ref
    53   val simp: thm -> thm
    59   val simp: (thm -> thm) ref
    54   val is_nat: typ list * term -> bool
       
    55   val mk_nat_thm: Sign.sg -> term -> thm
       
    56 end;
    60 end;
    57 (*
    61 (*
    58 decomp(`x Rel y') should yield (p,i,Rel,q,j)
    62 decomp(`x Rel y') should yield (p,i,Rel,q,j)
    59    where Rel is one of "<", "~<", "<=", "~<=" and "=" and
    63    where Rel is one of "<", "~<", "<=", "~<=" and "=" and
    60          p/q is the decomposition of the sum terms x/y into a list
    64          p/q is the decomposition of the sum terms x/y into a list
    61          of summand * multiplicity pairs and a constant summand.
    65          of summand * multiplicity pairs and a constant summand.
    62 
    66 
    63 simp must reduce contradictory <= to False.
    67 simp must reduce contradictory <= to False.
    64    It should also cancel common summands to keep <= reduced;
    68    It should also cancel common summands to keep <= reduced;
    65    otherwise <= can grow to massive proportions.
    69    otherwise <= can grow to massive proportions.
    66 
       
    67 is_nat(parameter-types,t) =  t:nat
       
    68 mk_nat_thm(t) = "0 <= t"
       
    69 *)
    70 *)
    70 
    71 
    71 signature FAST_LIN_ARITH =
    72 signature FAST_LIN_ARITH =
    72 sig
    73 sig
    73   val lin_arith_prover: Sign.sg -> thm list -> term -> thm option
    74   val lin_arith_prover: Sign.sg -> thm list -> term -> thm option
    88 
    89 
    89 datatype lineq_type = Eq | Le | Lt;
    90 datatype lineq_type = Eq | Le | Lt;
    90 
    91 
    91 datatype injust = Asm of int
    92 datatype injust = Asm of int
    92                 | Nat of int (* index of atom *)
    93                 | Nat of int (* index of atom *)
    93                 | Fwd of injust * thm
    94                 | LessD of injust
       
    95                 | NotLessD of injust
       
    96                 | NotLeD of injust
    94                 | Multiplied of int * injust
    97                 | Multiplied of int * injust
    95                 | Added of injust * injust;
    98                 | Added of injust * injust;
    96 
    99 
    97 datatype lineq = Lineq of int * lineq_type * int list * injust;
   100 datatype lineq = Lineq of int * lineq_type * int list * injust;
    98 
   101 
   248  exception FalseE of thm
   251  exception FalseE of thm
   249 in
   252 in
   250 fun mkthm sg asms just =
   253 fun mkthm sg asms just =
   251   let val atoms = foldl (fn (ats,(lhs,_,_,rhs,_)) =>
   254   let val atoms = foldl (fn (ats,(lhs,_,_,rhs,_)) =>
   252                             map fst lhs  union  (map fst rhs  union  ats))
   255                             map fst lhs  union  (map fst rhs  union  ats))
   253                         ([], mapfilter (LA_Data.decomp o concl_of) asms)
   256                         ([], mapfilter (!LA_Data.decomp o concl_of) asms)
   254 
   257 
   255       fun addthms thm1 thm2 =
   258       fun addthms thm1 thm2 =
   256         let val conj = thm1 RS (thm2 RS LA_Logic.conjI)
   259         let val conj = thm1 RS (thm2 RS LA_Logic.conjI)
   257         in the(get_first (fn th => Some(conj RS th) handle _ => None)
   260         in the(get_first (fn th => Some(conj RS th) handle _ => None)
   258                          LA_Data.add_mono_thms)
   261                          (!LA_Data.add_mono_thms))
   259         end;
   262         end;
   260 
   263 
   261       fun multn(n,thm) =
   264       fun multn(n,thm) =
   262         let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
   265         let fun mul(i,th) = if i=1 then th else mul(i-1, addthms thm th)
   263         in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
   266         in if n < 0 then mul(~n,thm) RS LA_Logic.sym else mul(n,thm) end;
   264 
   267 
   265       fun simp thm =
   268       fun simp thm =
   266         let val thm' = LA_Data.simp thm
   269         let val thm' = !LA_Data.simp thm
   267         in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
   270         in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end
   268 
   271 
   269       fun mk(Asm i) = nth_elem(i,asms)
   272       fun mk(Asm i) = nth_elem(i,asms)
   270         | mk(Nat(i)) = LA_Data.mk_nat_thm sg (nth_elem(i,atoms))
   273         | mk(Nat(i)) = LA_Logic.mk_nat_thm sg (nth_elem(i,atoms))
   271         | mk(Fwd(j,thm)) = mk j RS thm
   274         | mk(LessD(j)) = hd([mk j] RL !LA_Data.lessD)
       
   275         | mk(NotLeD(j)) = hd([mk j RS LA_Logic.not_leD] RL !LA_Data.lessD)
       
   276         | mk(NotLessD(j)) = mk j RS LA_Logic.not_lessD
   272         | mk(Added(j1,j2)) = simp(addthms (mk j1) (mk j2))
   277         | mk(Added(j1,j2)) = simp(addthms (mk j1) (mk j2))
   273         | mk(Multiplied(n,j)) = multn(n,mk j)
   278         | mk(Multiplied(n,j)) = multn(n,mk j)
   274 
   279 
   275   in LA_Data.simp(mk just) handle FalseE thm => thm end
   280   in !LA_Data.simp(mk just) handle FalseE thm => thm end
   276 end;
   281 end;
   277 
   282 
   278 fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i;
   283 fun coeff poly atom = case assoc(poly,atom) of None => 0 | Some i => i;
   279 
   284 
   280 fun mklineq atoms =
   285 fun mklineq atoms =
   285         val diff = map2 (op -) (rhsa,lhsa)
   290         val diff = map2 (op -) (rhsa,lhsa)
   286         val c = i-j
   291         val c = i-j
   287         val just = Asm k
   292         val just = Asm k
   288     in case rel of
   293     in case rel of
   289         "<="   => Some(Lineq(c,Le,diff,just))
   294         "<="   => Some(Lineq(c,Le,diff,just))
   290        | "~<=" => Some(Lineq(1-c,Le,map (op ~) diff,Fwd(just,LA_Data.not_leD)))
   295        | "~<=" => Some(Lineq(1-c,Le,map (op ~) diff,NotLeD(just)))
   291        | "<"   => Some(Lineq(c+1,Le,diff,Fwd(just,LA_Data.lessD)))
   296        | "<"   => Some(Lineq(c+1,Le,diff,LessD(just)))
   292        | "~<"  => Some(Lineq(~c,Le,map (op~) diff,Fwd(just,LA_Logic.not_lessD)))
   297        | "~<"  => Some(Lineq(~c,Le,map (op~) diff,NotLessD(just)))
   293        | "="   => Some(Lineq(c,Eq,diff,just))
   298        | "="   => Some(Lineq(c,Eq,diff,just))
   294        | "~="  => None
   299        | "~="  => None
   295        | _     => sys_error("mklineq" ^ rel)   
   300        | _     => sys_error("mklineq" ^ rel)   
   296     end
   301     end
   297   end;
   302   end;
   298 
   303 
   299 fun mknat pTs ixs (atom,i) =
   304 fun mknat pTs ixs (atom,i) =
   300   if LA_Data.is_nat(pTs,atom)
   305   if LA_Logic.is_nat(pTs,atom)
   301   then let val l = map (fn j => if j=i then 1 else 0) ixs
   306   then let val l = map (fn j => if j=i then 1 else 0) ixs
   302        in Some(Lineq(0,Le,l,Nat(i))) end
   307        in Some(Lineq(0,Le,l,Nat(i))) end
   303   else None
   308   else None
   304 
   309 
   305 fun abstract pTs items =
   310 fun abstract pTs items =
   343     state;
   348     state;
   344 
   349 
   345 fun prove(pTs,Hs,concl) =
   350 fun prove(pTs,Hs,concl) =
   346 let val nHs = length Hs
   351 let val nHs = length Hs
   347     val ixHs = Hs ~~ (0 upto (nHs-1))
   352     val ixHs = Hs ~~ (0 upto (nHs-1))
   348     val Hitems = mapfilter (fn (h,i) => case LA_Data.decomp h of
   353     val Hitems = mapfilter (fn (h,i) => case !LA_Data.decomp h of
   349                                  None => None | Some(it) => Some(it,i)) ixHs
   354                                  None => None | Some(it) => Some(it,i)) ixHs
   350 in case LA_Data.decomp concl of
   355 in case !LA_Data.decomp concl of
   351      None => if null Hitems then [] else refute1(pTs,Hitems)
   356      None => if null Hitems then [] else refute1(pTs,Hitems)
   352    | Some(citem as (r,i,rel,l,j)) =>
   357    | Some(citem as (r,i,rel,l,j)) =>
   353        if rel = "="
   358        if rel = "="
   354        then refute2(pTs,Hitems,citem,nHs)
   359        then refute2(pTs,Hitems,citem,nHs)
   355        else let val neg::rel0 = explode rel
   360        else let val neg::rel0 = explode rel