src/Provers/Arith/fast_lin_arith.ML
changeset 24630 351a308ab58d
parent 24112 6c4e7d17f9b0
child 24920 2a45e400fdad
equal deleted inserted replaced
24629:65947eb930fa 24630:351a308ab58d
    56   (*preprocessing, performed on a representation of subgoals as list of premises:*)
    56   (*preprocessing, performed on a representation of subgoals as list of premises:*)
    57   val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
    57   val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
    58 
    58 
    59   (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
    59   (*preprocessing, performed on the goal -- must do the same as 'pre_decomp':*)
    60   val pre_tac: Proof.context -> int -> tactic
    60   val pre_tac: Proof.context -> int -> tactic
    61   val number_of: IntInf.int * typ -> term
    61   val number_of: int * typ -> term
    62 
    62 
    63   (*the limit on the number of ~= allowed; because each ~= is split
    63   (*the limit on the number of ~= allowed; because each ~= is split
    64     into two cases, this can lead to an explosion*)
    64     into two cases, this can lead to an explosion*)
    65   val fast_arith_neq_limit: int Config.T
    65   val fast_arith_neq_limit: int Config.T
    66 end;
    66 end;
   152                 | Nat of int (* index of atom *)
   152                 | Nat of int (* index of atom *)
   153                 | LessD of injust
   153                 | LessD of injust
   154                 | NotLessD of injust
   154                 | NotLessD of injust
   155                 | NotLeD of injust
   155                 | NotLeD of injust
   156                 | NotLeDD of injust
   156                 | NotLeDD of injust
   157                 | Multiplied of IntInf.int * injust
   157                 | Multiplied of int * injust
   158                 | Multiplied2 of IntInf.int * injust
   158                 | Multiplied2 of int * injust
   159                 | Added of injust * injust;
   159                 | Added of injust * injust;
   160 
   160 
   161 datatype lineq = Lineq of IntInf.int * lineq_type * IntInf.int list * injust;
   161 datatype lineq = Lineq of int * lineq_type * int list * injust;
   162 
   162 
   163 (* ------------------------------------------------------------------------- *)
   163 (* ------------------------------------------------------------------------- *)
   164 (* Finding a (counter) example from the trace of a failed elimination        *)
   164 (* Finding a (counter) example from the trace of a failed elimination        *)
   165 (* ------------------------------------------------------------------------- *)
   165 (* ------------------------------------------------------------------------- *)
   166 (* Examples are represented as rational numbers,                             *)
   166 (* Examples are represented as rational numbers,                             *)
   176 fun elim_eqns (Lineq (i, Le, cs, _)) = [(i, true, cs)]
   176 fun elim_eqns (Lineq (i, Le, cs, _)) = [(i, true, cs)]
   177   | elim_eqns (Lineq (i, Eq, cs, _)) = [(i, true, cs),(~i, true, map ~ cs)]
   177   | elim_eqns (Lineq (i, Eq, cs, _)) = [(i, true, cs),(~i, true, map ~ cs)]
   178   | elim_eqns (Lineq (i, Lt, cs, _)) = [(i, false, cs)];
   178   | elim_eqns (Lineq (i, Lt, cs, _)) = [(i, false, cs)];
   179 
   179 
   180 (* PRE: ex[v] must be 0! *)
   180 (* PRE: ex[v] must be 0! *)
   181 fun eval ex v (a:IntInf.int,le,cs:IntInf.int list) =
   181 fun eval ex v (a, le, cs) =
   182   let
   182   let
   183     val rs = map Rat.rat_of_int cs;
   183     val rs = map Rat.rat_of_int cs;
   184     val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
   184     val rsum = fold2 (Rat.add oo Rat.mult) rs ex Rat.zero;
   185   in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
   185   in (Rat.mult (Rat.add (Rat.rat_of_int a) (Rat.neg rsum)) (Rat.inv (nth rs v)), le) end;
   186 (* If nth rs v < 0, le should be negated.
   186 (* If nth rs v < 0, le should be negated.
   330 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   330 fun is_trivial (Lineq(_,_,l,_)) = forall (fn i => i=0) l;
   331 
   331 
   332 fun is_answer (ans as Lineq(k,ty,l,_)) =
   332 fun is_answer (ans as Lineq(k,ty,l,_)) =
   333   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   333   case ty  of Eq => k <> 0 | Le => k > 0 | Lt => k >= 0;
   334 
   334 
   335 fun calc_blowup (l:IntInf.int list) =
   335 fun calc_blowup l =
   336   let val (p,n) = List.partition (curry (op <) 0) (List.filter (curry (op <>) 0) l)
   336   let val (p,n) = List.partition (curry (op <) 0) (List.filter (curry (op <>) 0) l)
   337   in (length p) * (length n) end;
   337   in length p * length n end;
   338 
   338 
   339 (* ------------------------------------------------------------------------- *)
   339 (* ------------------------------------------------------------------------- *)
   340 (* Main elimination code:                                                    *)
   340 (* Main elimination code:                                                    *)
   341 (*                                                                           *)
   341 (*                                                                           *)
   342 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   342 (* (1) Looks for immediate solutions (false assertions with no variables).   *)
   358   in extract [] end;
   358   in extract [] end;
   359 
   359 
   360 fun print_ineqs ineqs =
   360 fun print_ineqs ineqs =
   361   if !trace then
   361   if !trace then
   362      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   362      tracing(cat_lines(""::map (fn Lineq(c,t,l,_) =>
   363        IntInf.toString c ^
   363        string_of_int c ^
   364        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   364        (case t of Eq => " =  " | Lt=> " <  " | Le => " <= ") ^
   365        commas(map IntInf.toString l)) ineqs))
   365        commas(map string_of_int l)) ineqs))
   366   else ();
   366   else ();
   367 
   367 
   368 type history = (int * lineq list) list;
   368 type history = (int * lineq list) list;
   369 datatype result = Success of injust | Failure of history;
   369 datatype result = Success of injust | Failure of history;
   370 
   370 
   379   if null nontriv then Failure hist
   379   if null nontriv then Failure hist
   380   else
   380   else
   381   let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   381   let val (eqs, noneqs) = List.partition (fn (Lineq(_,ty,_,_)) => ty=Eq) nontriv in
   382   if not (null eqs) then
   382   if not (null eqs) then
   383      let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
   383      let val clist = Library.foldl (fn (cs,Lineq(_,_,l,_)) => l union cs) ([],eqs)
   384          val sclist = sort (fn (x,y) => IntInf.compare(abs(x),abs(y)))
   384          val sclist = sort (fn (x,y) => int_ord (abs x, abs y))
   385                            (List.filter (fn i => i<>0) clist)
   385                            (List.filter (fn i => i<>0) clist)
   386          val c = hd sclist
   386          val c = hd sclist
   387          val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
   387          val (SOME(eq as Lineq(_,_,ceq,_)),othereqs) =
   388                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
   388                extract_first (fn Lineq(_,_,l,_) => c mem l) eqs
   389          val v = find_index_eq c ceq
   389          val v = find_index_eq c ceq
   485         | mk (LessD j)            = trace_thm "L" (hd ([mk j] RL lessD))
   485         | mk (LessD j)            = trace_thm "L" (hd ([mk j] RL lessD))
   486         | mk (NotLeD j)           = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
   486         | mk (NotLeD j)           = trace_thm "NLe" (mk j RS LA_Logic.not_leD)
   487         | mk (NotLeDD j)          = trace_thm "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
   487         | mk (NotLeDD j)          = trace_thm "NLeD" (hd ([mk j RS LA_Logic.not_leD] RL lessD))
   488         | mk (NotLessD j)         = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
   488         | mk (NotLessD j)         = trace_thm "NL" (mk j RS LA_Logic.not_lessD)
   489         | mk (Added (j1, j2))     = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
   489         | mk (Added (j1, j2))     = simp (trace_thm "+" (addthms (mk j1) (mk j2)))
   490         | mk (Multiplied (n, j))  = (trace_msg ("*" ^ IntInf.toString n); trace_thm "*" (multn (n, mk j)))
   490         | mk (Multiplied (n, j))  = (trace_msg ("*" ^ string_of_int n); trace_thm "*" (multn (n, mk j)))
   491         | mk (Multiplied2 (n, j)) = simp (trace_msg ("**" ^ IntInf.toString n); trace_thm "**" (multn2 (n, mk j)))
   491         | mk (Multiplied2 (n, j)) = simp (trace_msg ("**" ^ string_of_int n); trace_thm "**" (multn2 (n, mk j)))
   492 
   492 
   493   in trace_msg "mkthm";
   493   in trace_msg "mkthm";
   494      let val thm = trace_thm "Final thm:" (mk just)
   494      let val thm = trace_thm "Final thm:" (mk just)
   495      in let val fls = simplify simpset' thm
   495      in let val fls = simplify simpset' thm
   496         in trace_thm "After simplification:" fls;
   496         in trace_thm "After simplification:" fls;
   505      end handle FalseE thm => trace_thm "False reached early:" thm
   505      end handle FalseE thm => trace_thm "False reached early:" thm
   506   end
   506   end
   507 end;
   507 end;
   508 
   508 
   509 fun coeff poly atom =
   509 fun coeff poly atom =
   510   AList.lookup (op aconv) poly atom |> the_default (0: integer);
   510   AList.lookup (op aconv) poly atom |> the_default 0;
   511 
       
   512 fun lcms ks = fold Integer.lcm ks 1;
       
   513 
   511 
   514 fun integ(rlhs,r,rel,rrhs,s,d) =
   512 fun integ(rlhs,r,rel,rrhs,s,d) =
   515 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
   513 let val (rn,rd) = Rat.quotient_of_rat r and (sn,sd) = Rat.quotient_of_rat s
   516     val m = lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   514     val m = Integer.lcms(map (abs o snd o Rat.quotient_of_rat) (r :: s :: map snd rlhs @ map snd rrhs))
   517     fun mult(t,r) =
   515     fun mult(t,r) =
   518         let val (i,j) = Rat.quotient_of_rat r
   516         let val (i,j) = Rat.quotient_of_rat r
   519         in (t,i * (m div j)) end
   517         in (t,i * (m div j)) end
   520 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   518 in (m,(map mult rlhs, rn*(m div rd), rel, map mult rrhs, sn*(m div sd), d)) end
   521 
   519 
   545 (* Print (counter) example                                                   *)
   543 (* Print (counter) example                                                   *)
   546 (* ------------------------------------------------------------------------- *)
   544 (* ------------------------------------------------------------------------- *)
   547 
   545 
   548 fun print_atom((a,d),r) =
   546 fun print_atom((a,d),r) =
   549   let val (p,q) = Rat.quotient_of_rat r
   547   let val (p,q) = Rat.quotient_of_rat r
   550       val s = if d then IntInf.toString p else
   548       val s = if d then string_of_int p else
   551               if p = 0 then "0"
   549               if p = 0 then "0"
   552               else IntInf.toString p ^ "/" ^ IntInf.toString q
   550               else string_of_int p ^ "/" ^ string_of_int q
   553   in a ^ " = " ^ s end;
   551   in a ^ " = " ^ s end;
   554 
   552 
   555 fun produce_ex sds =
   553 fun produce_ex sds =
   556   curry (op ~~) sds
   554   curry (op ~~) sds
   557   #> map print_atom
   555   #> map print_atom