src/HOL/Integ/barith.ML
changeset 15298 a5bea99352d6
parent 15297 0aff5d912422
child 15299 576fd0b65ed8
equal deleted inserted replaced
15297:0aff5d912422 15298:a5bea99352d6
     1 (**************************************************************)
       
     2 (*                                                            *)
       
     3 (*                                                            *)
       
     4 (*          Trying to implement a bounded arithmetic prover   *)
       
     5 (*           Chaieb Amine                                     *)
       
     6 (*                                                            *)
       
     7 (**************************************************************)
       
     8   
       
     signature BARITH =
sig
  (* Theory setup functions; the structure's setup registers the
     "barith" proof method. *)
  val setup : (theory -> theory) list

  (* barith_tac i: tries to prove subgoal i by abstracting the arithmetic
     expressions of its conclusion over the variable intervals found in
     its premises. *)
  val barith_tac : int -> tactic
end;
       
    15 
       
    16 
       
    structure Barith =
struct

(* Theorems we use from Barith.thy *)
val abs_const = thm "abs_const";
val abs_var = thm "abs_var";
val abs_neg = thm "abs_neg";
val abs_add = thm "abs_add";
val abs_sub = thm "abs_sub";
val abs_sub_x = thm "abs_sub_x";
val abs_mul = thm "abs_mul";
val abs_mul_x = thm "abs_mul_x";
val subinterval = thm "subinterval";
val imp_commute = thm "imp_commute";
val imp_simplify = thm "imp_simplify";

exception NORMCONJ of string;

(* interval_of_conj t: collect variable bounds from one premise t.
   Returns a list of (x, (l, u)) where x is a Free variable and a missing
   bound is encoded as HOLogic.false_const.  A conjunction of a lower and an
   upper bound on the same int-typed Free x yields one entry carrying both
   bounds; any other conjunction is split and the parts merged with union.
   Raises NORMCONJ for any other shape of premise. *)
fun interval_of_conj t = case t of
    Const("op &",_) $
      (t1 as (Const("op <=",_) $ l1 $ (x as Free(xn,xT)))) $
      (t2 as (Const("op <=",_) $ y $ u1)) =>
      if x = y andalso type_of x = HOLogic.intT
      then [(x,(l1,u1))]
      else (interval_of_conj t1) union (interval_of_conj t2)
  | Const("op &",_) $
      (t1 as (Const("op <=",_) $ y $ u1)) $
      (t2 as (Const("op <=",_) $ l1 $ (x as Free(xn,xT)))) =>
      if x = y andalso type_of x = HOLogic.intT
      then [(x,(l1,u1))]
      else (interval_of_conj t1) union (interval_of_conj t2)
  | Const("op <=",_) $ l $ (x as Free(xn,xT)) => [(x,(l,HOLogic.false_const))]
  | Const("op <=",_) $ (x as Free(xn,xT)) $ u => [(x,(HOLogic.false_const,u))]
  | Const("op &",_) $ t1 $ t2 => (interval_of_conj t1) union (interval_of_conj t2)
  | _ => raise (NORMCONJ "Not in normal form - unknown conjunct");


(* The input to intervals_of_premise is a meta-implication of the form
     l1 <= x1 & x1 <= u1 ==> ... ==> ln <= xn & xn <= un ==> C
   and the output is a list of (variable, (lower, upper)) pairs with at most
   one entry per variable; HOLogic.false_const stands for a missing bound. *)

val iT = HOLogic.intT;

(* Symbolic max / min of two int bounds.  The constant False ("no bound")
   acts as the identity, so merging with a missing bound keeps the other. *)
fun maxterm (Const("False",_)) t = t
  | maxterm t (Const("False",_)) = t
  | maxterm t1 t2 = Const("HOL.max",iT --> iT --> iT) $ t1 $ t2;

fun minterm (Const("False",_)) t = t
  | minterm t (Const("False",_)) = t
  | minterm t1 t2 = Const("HOL.min",iT --> iT --> iT) $ t1 $ t2;

fun intervals_of_premise p =
  let
    val ps = map HOLogic.dest_Trueprop (Logic.strip_imp_prems p)

    (* Tighter of two lower bounds: a numeral when both are numerals,
       otherwise a symbolic max term. *)
    fun join_lo l l' =
      if CooperDec.is_numeral l andalso CooperDec.is_numeral l'
      then CooperDec.mk_numeral
             (Int.max (CooperDec.dest_numeral l, CooperDec.dest_numeral l'))
      else maxterm l l'

    (* Tighter of two upper bounds: a numeral when both are numerals,
       otherwise a symbolic min term. *)
    fun join_hi u u' =
      if CooperDec.is_numeral u andalso CooperDec.is_numeral u'
      then CooperDec.mk_numeral
             (Int.min (CooperDec.dest_numeral u, CooperDec.dest_numeral u'))
      else minterm u u'

    (* Add bound b for x to the already-merged list ls'.  If x already has
       an entry b', that entry is replaced by merge b'. *)
    fun add x b merge ls' =
      case assoc (ls',x) of
        None => (x,b) :: ls'
      | Some b' => (x, merge b') :: filter (fn p => not (fst p = x)) ls'

    (* Merge all bounds of the same variable into a single entry; entries
       with neither bound are dropped. *)
    fun tight [] = []
      | tight ((x,(Const("False",_),Const("False",_)))::ls) = tight ls
      | tight ((x, b as (Const("False",_), u))::ls) =
          add x b (fn (l',u') => (l', join_hi u u')) (tight ls)
      | tight ((x, b as (l, Const("False",_)))::ls) =
          add x b (fn (l',u') => (join_lo l l', u')) (tight ls)
      | tight ((x, b as (l,u))::ls) =
          add x b (fn (l',u') => (join_lo l l', join_hi u u')) (tight ls)
  in tight (foldr (fn (p,l) => (interval_of_conj p) union l) (ps,[]))
  end;

(* exp_of_concl p: read the goal bounds off the conclusion p.
   Returns a list of (expression, (Some lower | None, Some upper | None)),
   one entry per conjunct handled by barith_tac, in left-to-right order.
   "l <= e & e <= u" gives a single two-sided entry.  A conjunction of two
   "<=" facts on different expressions (including the reversed order
   "e <= u & l <= e") is split into its single-sided conjuncts.  The
   previous second pattern for the reversed order had the same shape as the
   first one and could never match.  A single "a <= b" needs a numeral on
   one side.  Raises NORMCONJ otherwise. *)
fun exp_of_concl p = case p of
    Const("op &",_) $
      (a as (Const("op <=",_) $ l $ e)) $
      (b as (Const("op <=",_) $ e' $ u)) =>
      if e = e' then [(e,(Some l,Some u))]
      else (exp_of_concl a) @ (exp_of_concl b)
  | Const("op <=",_) $ e $ u =>
      if CooperDec.is_numeral u then [(e,(None,Some u))]
      else if CooperDec.is_numeral e then [(u,(Some e,None))]
      else raise NORMCONJ "Bounds has to be numerals"
  | Const("op &",_) $ a $ b => (exp_of_concl a) @ (exp_of_concl b)
  | _ => raise NORMCONJ "Conclusion not in normal form---unknown connective";


(* strip_problem p: split the goal p into the merged premise intervals and
   the list of conclusion bounds. *)
fun strip_problem p =
  let
    val is = intervals_of_premise p
    val e = exp_of_concl ((HOLogic.dest_Trueprop o Logic.strip_imp_concl) p)
  in (is,e)
  end;


(* Abstract interpretation of Intervals over theorems *)
exception ABSEXP of string;

(* decomp_absexp sg is e: one decomposition step for the interval theorem
   of expression e.  It returns the immediate subterms of e together with a
   function that combines their theorems via the abs_* lemmas.  Presumably
   CooperProof.thm_of applies it recursively (confirm against CooperProof).
   Variables take their interval from is and must have both bounds.
   Raises ABSEXP for an unbounded variable or an unknown operator. *)
fun decomp_absexp sg is e = case e of
    Free(xn,_) =>
      ([], fn [] =>
        (case assoc (is,e) of
           (* A one-sided bound is stored as the bool constant False and
              cannot instantiate the int bounds of abs_var. *)
           Some (Const("False",_), _) =>
             raise ABSEXP ("No lower bound for variable   " ^ xn)
         | Some (_, Const("False",_)) =>
             raise ABSEXP ("No upper bound for variable   " ^ xn)
         | Some (l,u) =>
             instantiate' [] (map (fn a => Some (cterm_of sg a)) [l,e,u]) abs_var
         | None => raise ABSEXP ("No Interval for Variable   " ^ xn)))
  | Const("op +",_) $ e1 $ e2 =>
      ([e1,e2], fn [th1,th2] => [th1,th2] MRS abs_add)
  | Const("op -",_) $ e1 $ e2 =>
      if e1 = e2 then ([e1], fn [th] => th RS abs_sub_x)
      else ([e1,e2], fn [th1,th2] => [th1,th2] MRS abs_sub)
  | Const("op *",_) $ e1 $ e2 =>
      if e1 = e2 then ([e1], fn [th] => th RS abs_mul_x)
      else ([e1,e2], fn [th1,th2] => [th1,th2] MRS abs_mul)
  | Const("op uminus",_) $ e' =>
      ([e'], fn [th] => th RS abs_neg)
  | _ =>
      if CooperDec.is_numeral e
      then ([], fn [] => instantiate' [] [Some (cterm_of sg e)] abs_const)
      else raise ABSEXP "Unknown arithmetical expression";

(* absexp sg is (e,(lo,uo)): a theorem stating that e lies within the
   requested bounds, with side conditions left as premises by subinterval.
   Two-sided requests return the widened conjunction.  For a one-sided
   request, the missing bound is taken from the computed interval of e and
   the matching conjunct is projected out (conjunct2 for an upper bound,
   conjunct1 for a lower bound).  Raises ABSEXP when no bound is given. *)
fun absexp sg is (e,(lo,uo)) =
  let
    (* Resolve th, which proves "a <= e & e <= b", against subinterval
       instantiated with the target bounds l and u. *)
    fun widen th l u =
      th RS instantiate' []
              [None,None,None,Some (cterm_of sg l),Some (cterm_of sg u)]
              subinterval
    fun interval () = CooperProof.thm_of sg (decomp_absexp sg is) e
    (* The computed lower and upper bound of an interval theorem. *)
    fun computed_bounds th =
      case (HOLogic.dest_Trueprop o concl_of) th of
        Const("op &",_) $ (Const("op <=",_) $ l $ _) $ (Const("op <=",_) $ _ $ u) =>
          (l,u)
      | _ => raise ABSEXP "Unexpected form of computed interval"
  in
    case (lo,uo) of
      (Some l, Some u) => widen (interval ()) l u
    | (None, Some u) =>
        let val th1 = interval ()
            val (l,_) = computed_bounds th1
        in widen th1 l u RS conjunct2
        end
    | (Some l, None) =>
        let val th1 = interval ()
            val (_,u) = computed_bounds th1
        in widen th1 l u RS conjunct1
        end
    | (None, None) => raise ABSEXP "No bounds for conclusion"
  end;

(* free_occ e: number of occurrences (with multiplicity) of int-typed Free
   variables in e. *)
fun free_occ e = case e of
    Free(_,i) => if i = HOLogic.intT then 1 else 0
  | f $ a => (free_occ f) + (free_occ a)
  | Abs(_,_,p) => free_occ p
  | _ => 0;


(* Superseded single-conclusion version, kept for reference:
fun simp_exp sg p =
  let val (is,(e,(l,u))) = strip_problem p
      val th = absexp sg is (e,(l,u))
      val _ = prth th
  in (th, free_occ e)
end;
*)

(* simp_exp sg p: one interval theorem per conclusion conjunct of p.  It
   also returns the total count of int Free occurrences in the conclusion
   expressions, which barith_tac uses as the repeat bound when discharging
   premises. *)
fun simp_exp sg p =
  let val (is,es) = strip_problem p
      val ths = map (absexp sg is) es
      val n = foldr (fn ((e,(_,_)),x) => (free_occ e) + x) (es,0)
  in (ths, n)
  end;


(* ============================ *)
(*      The barith Tactic       *)
(* ============================ *)

(* Superseded version, kept for reference:
fun barith_tac i = ObjectLogic.atomize_tac i THEN (fn st =>
  let
    fun assm_tac n j = REPEAT_DETERM_N n ((assume_tac j) ORELSE (simple_arith_tac j))
    val g = BasisLibrary.List.nth (prems_of st, i - 1)
    val sg = sign_of_thm st
    val ss = (simpset_of (the_context())) addsimps [max_def,min_def]
    val (th,n) = simp_exp sg g
  in (rtac th i
        THEN assm_tac n i
        THEN (TRY (REPEAT_DETERM_N 2 (simp_tac ss i)))) st
end);
*)

(* barith_tac i: atomize subgoal i, then simplify a copy of it with
   simpset0 (z_less_imp_le1, z_eq_imp_le_conj) and mysimpset (numeral
   evaluation).  The simplified copy is used to compute the interval
   theorems.  The tactic then splits the conclusion with conjI, resolves
   each theorem in turn, discharges its premises by assumption or
   simple_arith_tac, and finishes with up to two simp steps that unfold
   max/min.  If the simplified copy is not an implication, it just tries
   assume_tac.
   NOTE(review): the simplification is only applied to the copy, not to
   st; results rely on simple_arith_tac to bridge the difference.  The
   conjI splitting assumes a right-nested conclusion. *)
fun barith_tac i = ObjectLogic.atomize_tac i THEN (fn st =>
  let
    fun assm_tac n j =
      REPEAT_DETERM_N n ((assume_tac j) ORELSE (simple_arith_tac j))
    val g = BasisLibrary.List.nth (prems_of st, i - 1)
    val sg = sign_of_thm st
    val ss = (simpset_of (theory "Barith")) addsimps [max_def,min_def]
    val cg = cterm_of sg g
    val mybinarith =
      map thm ["Pls_0_eq", "Min_1_eq",
               "bin_pred_Pls", "bin_pred_Min", "bin_pred_1",
               "bin_pred_0", "bin_succ_Pls", "bin_succ_Min",
               "bin_succ_1", "bin_succ_0",
               "bin_add_Pls", "bin_add_Min", "bin_add_BIT_0",
               "bin_add_BIT_10",
               "bin_add_BIT_11", "bin_minus_Pls", "bin_minus_Min",
               "bin_minus_1",
               "bin_minus_0", "bin_mult_Pls", "bin_mult_Min",
               "bin_mult_1", "bin_mult_0",
               "bin_add_Pls_right", "bin_add_Min_right",
               "abs_zero", "abs_one",
               "eq_number_of_eq",
               "iszero_number_of_Pls", "nonzero_number_of_Min",
               "iszero_number_of_0", "iszero_number_of_1",
               "less_number_of_eq_neg",
               "not_neg_number_of_Pls", "neg_number_of_Min",
               "neg_number_of_BIT",
               "le_number_of_eq"]
    val myringarith =
      [number_of_add RS sym, number_of_minus RS sym,
       diff_number_of_eq, number_of_mult RS sym,
       thm "zero_eq_Numeral0_nring", thm "one_eq_Numeral1_nring"]
    val mynatarith =
      [thm "zero_eq_Numeral0_nat", thm "one_eq_Numeral1_nat",
       thm "add_nat_number_of", thm "diff_nat_number_of",
       thm "mult_nat_number_of", thm "eq_nat_number_of",
       thm "less_nat_number_of"]
    val mypowerarith =
      [thm "nat_number_of", thm "zpower_number_of_even",
       thm "zpower_number_of_odd", thm "zpower_Pls", thm "zpower_Min"]
    val myiflet = [if_False, if_True, thm "Let_def"]
    val myifletcongs = [if_weak_cong, let_weak_cong]
    val mysimpset = HOL_basic_ss
      addsimps mybinarith
      addsimps myringarith
      addsimps mynatarith addsimps mypowerarith
      addsimps myiflet addsimps simp_thms
      addcongs myifletcongs
    val simpset0 = HOL_basic_ss
      addsimps [thm "z_less_imp_le1", thm "z_eq_imp_le_conj"]
    (* pre_thm : simplified_goal ==> g *)
    val pre_thm = Seq.hd (EVERY (map TRY
                    [simp_tac simpset0 1, simp_tac mysimpset 1])
                    (trivial cg))
    val tac = case (prop_of pre_thm) of
        Const ("==>", _) $ t1 $ _ =>
          let
            val (ths,n) = simp_exp sg t1
            val cn = length ths - 1
            (* Apply conjI at goals j, j+1, ..., j+thn-1. *)
            fun conjIs thn j = EVERY (map (rtac conjI) (j upto (thn + j - 1)))
            fun thtac thms j = EVERY (map
              (fn t => rtac t j THEN assm_tac n j
                       THEN (TRY (REPEAT_DETERM_N 2 (simp_tac ss j)))) thms)
          in ((conjIs cn i) THEN (thtac ths i))
          end
      | _ => assume_tac i
  in (tac st)
  end);

(* Argument parser for a flagged variant of the method; the parsed flags
   are currently ignored and meth is always applied to subgoal 1. *)
fun barith_args meth =
  let val parse_flag =
          Args.$$$ "no_quantify" >> K (apfst (K false))
       || Args.$$$ "abs" >> K (apsnd (K true));
  in
    Method.simple_args
      (Scan.optional (Args.$$$ "(" |-- Scan.repeat1 parse_flag --| Args.$$$ ")") []
       >>
       curry (foldl op |>) (true, false))
      (fn (q,a) => fn _ => meth 1)
  end;

(* Proof method: insert the chained facts into subgoal i, then run
   barith_tac on that same subgoal. *)
fun barith_method i = Method.METHOD (fn facts =>
  Method.insert_tac facts i THEN barith_tac i)

val setup =
  [Method.add_method ("barith",
     Method.no_args (barith_method 1),
     "VERY simple decision procedure for bounded arithmetic")];


(* End of Structure *)
end;
       
   348 
       
   349 (* Test *)
       
   350 (*
       
   351 open Barith;
       
   352 
       
   353 Goal "-1 <= (x::int) & x <= 1 ==> 0 <= (y::int) & y <= 5 + 7 ==> -13 <= x*x + y*x & x*x + y*x <= 20";
       
   354 by(barith_tac 1);
       
   355 
       
   356 Goal "-1 <= (x::int) & x <= 1 ==> 0 <= (y::int) & y <= 5 + 7 ==> 0 <= x - x  + y & x - x  + y<= 12";
       
   357 by(barith_tac 1);
       
   358 
       
   359 Goal "-1 <= (x::int) & x <= 1 ==> 0 <= (y::int) & y <= 5 + 7 ==> 0 <= x - x  + x*x & x - x  + x*x<= 1";
       
   360 by(barith_tac 1);
       
   361 
       
   362 Goal "(x::int) <= 1& 1 <= x ==> 0 <= (y::int) & y <= 5 + 7 ==> 0 <= x - x  + x*x & x - x  + x*x<= 1";
       
   363 by(barith_tac 1);
       
   364 
       
   365 Goal "(x::int) <= 1& 1 <= x ==> (t::int) <= 8 ==>(x::int) <= 2& 0 <= x ==> 0 <= (y::int) & y <= 5 + 7 ==> 0 <= x - x  + x*x & x - x  + x*x<= 1";
       
   366 by(barith_tac 1);
       
   367 
       
   368 Goal "-1 <= (x::int) ==>  x <= 1 & 1 <= (z::int) ==> z <= 2+3 ==> 0 <= (y::int) & y <= 5 + 7 ==> -4 <= x - x  + x*x";
       
   369 by(Barith.barith_tac 1);
       
   370 
       
   371 Goal "[|(0::int) <= x & x <= 5 ; 0 <= (y::int) & y <= 7|]==> (0 <= x*x*x & x*x*x <= 125 ) & (0 <= x*x & x*x <= 100) & (0 <= x*x + x & x*x + x <= 30) & (0<= x*y & x*y <= 35)";
       
   372 by (barith_tac 1);
       
   373 *)
       
   374 
       
   375 
       
   376 (*
       
   377 val st = topthm();
       
   378 val sg = sign_of_thm st; 
       
   379 val g = BasisLibrary.List.nth (prems_of st, 0);
       
   380 val (ths,n) = simp_exp sg g;
       
   381 fun assm_tac n j = REPEAT_DETERM_N n ((assume_tac j) ORELSE (simple_arith_tac j));
       
   382 
       
   383 *)