src/HOL/arith_data.ML
changeset 9436 62bb04ab4b01
child 9593 b732997cfc11
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/arith_data.ML	Tue Jul 25 00:06:46 2000 +0200
     1.3 @@ -0,0 +1,429 @@
     1.4 +(*  Title:      HOL/arith_data.ML
     1.5 +    ID:         $Id$
     1.6 +    Author:     Markus Wenzel, Stefan Berghofer and Tobias Nipkow
     1.7 +
     1.8 +Various arithmetic proof procedures.
     1.9 +*)
    1.10 +
    1.11 +(*---------------------------------------------------------------------------*)
    1.12 +(* 1. Cancellation of common terms                                           *)
    1.13 +(*---------------------------------------------------------------------------*)
    1.14 +
    1.15 +signature ARITH_DATA =
    1.16 +sig
    1.17 +  val nat_cancel_sums_add: simproc list
    1.18 +  val nat_cancel_sums: simproc list
    1.19 +  val nat_cancel_factor: simproc list
    1.20 +  val nat_cancel: simproc list
    1.21 +end;
    1.22 +
    1.23 +structure ArithData: ARITH_DATA =
    1.24 +struct
    1.25 +
    1.26 +
    1.27 +(** abstract syntax of structure nat: 0, Suc, + **)
    1.28 +
    1.29 +(* mk_sum, mk_norm_sum *)
    1.30 +
    1.31 +val one = HOLogic.mk_nat 1;
    1.32 +val mk_plus = HOLogic.mk_binop "op +";
    1.33 +
    1.34 +fun mk_sum [] = HOLogic.zero
    1.35 +  | mk_sum [t] = t
    1.36 +  | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    1.37 +
    1.38 +(*normal form of sums: Suc (... (Suc (a + (b + ...))))*)
    1.39 +fun mk_norm_sum ts =
    1.40 +  let val (ones, sums) = partition (equal one) ts in
    1.41 +    funpow (length ones) HOLogic.mk_Suc (mk_sum sums)
    1.42 +  end;
    1.43 +
    1.44 +
    1.45 +(* dest_sum *)
    1.46 +
    1.47 +val dest_plus = HOLogic.dest_bin "op +" HOLogic.natT;
    1.48 +
    1.49 +fun dest_sum tm =
    1.50 +  if HOLogic.is_zero tm then []
    1.51 +  else
    1.52 +    (case try HOLogic.dest_Suc tm of
    1.53 +      Some t => one :: dest_sum t
    1.54 +    | None =>
    1.55 +        (case try dest_plus tm of
    1.56 +          Some (t, u) => dest_sum t @ dest_sum u
    1.57 +        | None => [tm]));
    1.58 +
    1.59 +
    1.60 +(** generic proof tools **)
    1.61 +
    1.62 +(* prove conversions *)
    1.63 +
    1.64 +val mk_eqv = HOLogic.mk_Trueprop o HOLogic.mk_eq;
    1.65 +
    1.66 +fun prove_conv expand_tac norm_tac sg (t, u) =
    1.67 +  mk_meta_eq (prove_goalw_cterm_nocheck [] (cterm_of sg (mk_eqv (t, u)))
    1.68 +    (K [expand_tac, norm_tac]))
    1.69 +  handle ERROR => error ("The error(s) above occurred while trying to prove " ^
    1.70 +    (string_of_cterm (cterm_of sg (mk_eqv (t, u)))));
    1.71 +
    1.72 +val subst_equals = prove_goal HOL.thy "[| t = s; u = t |] ==> u = s"
    1.73 +  (fn prems => [cut_facts_tac prems 1, SIMPSET' asm_simp_tac 1]);
    1.74 +
    1.75 +
    1.76 +(* rewriting *)
    1.77 +
    1.78 +fun simp_all rules = ALLGOALS (simp_tac (HOL_ss addsimps rules));
    1.79 +
    1.80 +val add_rules = [add_Suc, add_Suc_right, add_0, add_0_right];
    1.81 +val mult_rules = [mult_Suc, mult_Suc_right, mult_0, mult_0_right];
    1.82 +
    1.83 +
    1.84 +
    1.85 +(** cancel common summands **)
    1.86 +
    1.87 +structure Sum =
    1.88 +struct
    1.89 +  val mk_sum = mk_norm_sum;
    1.90 +  val dest_sum = dest_sum;
    1.91 +  val prove_conv = prove_conv;
    1.92 +  val norm_tac = simp_all add_rules THEN simp_all add_ac;
    1.93 +end;
    1.94 +
    1.95 +fun gen_uncancel_tac rule ct =
    1.96 +  rtac (instantiate' [] [None, Some ct] (rule RS subst_equals)) 1;
    1.97 +
    1.98 +
    1.99 +(* nat eq *)
   1.100 +
   1.101 +structure EqCancelSums = CancelSumsFun
   1.102 +(struct
   1.103 +  open Sum;
   1.104 +  val mk_bal = HOLogic.mk_eq;
   1.105 +  val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT;
   1.106 +  val uncancel_tac = gen_uncancel_tac add_left_cancel;
   1.107 +end);
   1.108 +
   1.109 +
   1.110 +(* nat less *)
   1.111 +
   1.112 +structure LessCancelSums = CancelSumsFun
   1.113 +(struct
   1.114 +  open Sum;
   1.115 +  val mk_bal = HOLogic.mk_binrel "op <";
   1.116 +  val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT;
   1.117 +  val uncancel_tac = gen_uncancel_tac add_left_cancel_less;
   1.118 +end);
   1.119 +
   1.120 +
   1.121 +(* nat le *)
   1.122 +
   1.123 +structure LeCancelSums = CancelSumsFun
   1.124 +(struct
   1.125 +  open Sum;
   1.126 +  val mk_bal = HOLogic.mk_binrel "op <=";
   1.127 +  val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT;
   1.128 +  val uncancel_tac = gen_uncancel_tac add_left_cancel_le;
   1.129 +end);
   1.130 +
   1.131 +
   1.132 +(* nat diff *)
   1.133 +
   1.134 +structure DiffCancelSums = CancelSumsFun
   1.135 +(struct
   1.136 +  open Sum;
   1.137 +  val mk_bal = HOLogic.mk_binop "op -";
   1.138 +  val dest_bal = HOLogic.dest_bin "op -" HOLogic.natT;
   1.139 +  val uncancel_tac = gen_uncancel_tac diff_cancel;
   1.140 +end);
   1.141 +
   1.142 +
   1.143 +
   1.144 +(** cancel common factor **)
   1.145 +
   1.146 +structure Factor =
   1.147 +struct
   1.148 +  val mk_sum = mk_norm_sum;
   1.149 +  val dest_sum = dest_sum;
   1.150 +  val prove_conv = prove_conv;
   1.151 +  val norm_tac = simp_all (add_rules @ mult_rules) THEN simp_all add_ac;
   1.152 +end;
   1.153 +
   1.154 +fun mk_cnat n = cterm_of (Theory.sign_of (the_context ())) (HOLogic.mk_nat n);
   1.155 +
   1.156 +fun gen_multiply_tac rule k =
   1.157 +  if k > 0 then
   1.158 +    rtac (instantiate' [] [None, Some (mk_cnat (k - 1))] (rule RS subst_equals)) 1
   1.159 +  else no_tac;
   1.160 +
   1.161 +
   1.162 +(* nat eq *)
   1.163 +
   1.164 +structure EqCancelFactor = CancelFactorFun
   1.165 +(struct
   1.166 +  open Factor;
   1.167 +  val mk_bal = HOLogic.mk_eq;
   1.168 +  val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT;
   1.169 +  val multiply_tac = gen_multiply_tac Suc_mult_cancel1;
   1.170 +end);
   1.171 +
   1.172 +
   1.173 +(* nat less *)
   1.174 +
   1.175 +structure LessCancelFactor = CancelFactorFun
   1.176 +(struct
   1.177 +  open Factor;
   1.178 +  val mk_bal = HOLogic.mk_binrel "op <";
   1.179 +  val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT;
   1.180 +  val multiply_tac = gen_multiply_tac Suc_mult_less_cancel1;
   1.181 +end);
   1.182 +
   1.183 +
   1.184 +(* nat le *)
   1.185 +
   1.186 +structure LeCancelFactor = CancelFactorFun
   1.187 +(struct
   1.188 +  open Factor;
   1.189 +  val mk_bal = HOLogic.mk_binrel "op <=";
   1.190 +  val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT;
   1.191 +  val multiply_tac = gen_multiply_tac Suc_mult_le_cancel1;
   1.192 +end);
   1.193 +
   1.194 +
   1.195 +
   1.196 +(** prepare nat_cancel simprocs **)
   1.197 +
   1.198 +fun prep_pat s = Thm.read_cterm (Theory.sign_of (the_context ())) (s, HOLogic.termT);
   1.199 +val prep_pats = map prep_pat;
   1.200 +
   1.201 +fun prep_simproc (name, pats, proc) = Simplifier.mk_simproc name pats proc;
   1.202 +
   1.203 +val eq_pats = prep_pats ["(l::nat) + m = n", "(l::nat) = m + n", "Suc m = n", "m = Suc n"];
   1.204 +val less_pats = prep_pats ["(l::nat) + m < n", "(l::nat) < m + n", "Suc m < n", "m < Suc n"];
   1.205 +val le_pats = prep_pats ["(l::nat) + m <= n", "(l::nat) <= m + n", "Suc m <= n", "m <= Suc n"];
   1.206 +val diff_pats = prep_pats ["((l::nat) + m) - n", "(l::nat) - (m + n)", "Suc m - n", "m - Suc n"];
   1.207 +
   1.208 +val nat_cancel_sums_add = map prep_simproc
   1.209 +  [("nateq_cancel_sums", eq_pats, EqCancelSums.proc),
   1.210 +   ("natless_cancel_sums", less_pats, LessCancelSums.proc),
   1.211 +   ("natle_cancel_sums", le_pats, LeCancelSums.proc)];
   1.212 +
   1.213 +val nat_cancel_sums = nat_cancel_sums_add @
   1.214 +  [prep_simproc("natdiff_cancel_sums", diff_pats, DiffCancelSums.proc)];
   1.215 +
   1.216 +val nat_cancel_factor = map prep_simproc
   1.217 +  [("nateq_cancel_factor", eq_pats, EqCancelFactor.proc),
   1.218 +   ("natless_cancel_factor", less_pats, LessCancelFactor.proc),
   1.219 +   ("natle_cancel_factor", le_pats, LeCancelFactor.proc)];
   1.220 +
   1.221 +val nat_cancel = nat_cancel_factor @ nat_cancel_sums;
   1.222 +
   1.223 +
   1.224 +end;
   1.225 +
   1.226 +open ArithData;
   1.227 +
   1.228 +
   1.229 +(*---------------------------------------------------------------------------*)
   1.230 +(* 2. Linear arithmetic                                                      *)
   1.231 +(*---------------------------------------------------------------------------*)
   1.232 +
   1.233 +(* Parameters data for general linear arithmetic functor *)
   1.234 +
   1.235 +structure LA_Logic: LIN_ARITH_LOGIC =
   1.236 +struct
   1.237 +val ccontr = ccontr;
   1.238 +val conjI = conjI;
   1.239 +val neqE = linorder_neqE;
   1.240 +val notI = notI;
   1.241 +val sym = sym;
   1.242 +val not_lessD = linorder_not_less RS iffD1;
   1.243 +val not_leD = linorder_not_le RS iffD1;
   1.244 +
   1.245 +
   1.246 +fun mk_Eq thm = (thm RS Eq_FalseI) handle THM _ => (thm RS Eq_TrueI);
   1.247 +
   1.248 +val mk_Trueprop = HOLogic.mk_Trueprop;
   1.249 +
   1.250 +fun neg_prop(TP$(Const("Not",_)$t)) = TP$t
   1.251 +  | neg_prop(TP$t) = TP $ (Const("Not",HOLogic.boolT-->HOLogic.boolT)$t);
   1.252 +
   1.253 +fun is_False thm =
   1.254 +  let val _ $ t = #prop(rep_thm thm)
   1.255 +  in t = Const("False",HOLogic.boolT) end;
   1.256 +
   1.257 +fun is_nat(t) = fastype_of1 t = HOLogic.natT;
   1.258 +
   1.259 +fun mk_nat_thm sg t =
   1.260 +  let val ct = cterm_of sg t  and cn = cterm_of sg (Var(("n",0),HOLogic.natT))
   1.261 +  in instantiate ([],[(cn,ct)]) le0 end;
   1.262 +
   1.263 +end;
   1.264 +
   1.265 +
   1.266 +(* arith theory data *)
   1.267 +
   1.268 +structure ArithDataArgs =
   1.269 +struct
   1.270 +  val name = "HOL/arith";
   1.271 +  type T = {splits: thm list, discrete: (string * bool) list};
   1.272 +
   1.273 +  val empty = {splits = [], discrete = []};
   1.274 +  val copy = I;
   1.275 +  val prep_ext = I;
   1.276 +  fun merge ({splits = splits1, discrete = discrete1}, {splits = splits2, discrete = discrete2}) =
   1.277 +   {splits = Drule.merge_rules (splits1, splits2),
   1.278 +    discrete = merge_alists discrete1 discrete2};
   1.279 +  fun print _ _ = ();
   1.280 +end;
   1.281 +
   1.282 +structure ArithData = TheoryDataFun(ArithDataArgs);
   1.283 +
   1.284 +fun arith_split_add (thy, thm) = (ArithData.map (fn {splits, discrete} =>
   1.285 +  {splits = thm :: splits, discrete = discrete}) thy, thm);
   1.286 +
   1.287 +fun arith_discrete d = ArithData.map (fn {splits, discrete} =>
   1.288 +  {splits = splits, discrete = d :: discrete});
   1.289 +
   1.290 +
   1.291 +structure LA_Data_Ref: LIN_ARITH_DATA =
   1.292 +struct
   1.293 +
   1.294 +(* Decomposition of terms *)
   1.295 +
   1.296 +fun nT (Type("fun",[N,_])) = N = HOLogic.natT
   1.297 +  | nT _ = false;
   1.298 +
   1.299 +fun add_atom(t,m,(p,i)) = (case assoc(p,t) of None => ((t,m)::p,i)
   1.300 +                           | Some n => (overwrite(p,(t,n+m:int)), i));
   1.301 +
   1.302 +(* Turn term into list of summand * multiplicity plus a constant *)
   1.303 +fun poly(Const("op +",_) $ s $ t, m, pi) = poly(s,m,poly(t,m,pi))
   1.304 +  | poly(all as Const("op -",T) $ s $ t, m, pi) =
   1.305 +      if nT T then add_atom(all,m,pi)
   1.306 +      else poly(s,m,poly(t,~1*m,pi))
   1.307 +  | poly(Const("uminus",_) $ t, m, pi) = poly(t,~1*m,pi)
   1.308 +  | poly(Const("0",_), _, pi) = pi
   1.309 +  | poly(Const("Suc",_)$t, m, (p,i)) = poly(t, m, (p,i+m))
   1.310 +  | poly(all as Const("op *",_) $ (Const("Numeral.number_of",_)$c) $ t, m, pi)=
   1.311 +      (poly(t,m*HOLogic.dest_binum c,pi)
   1.312 +       handle TERM _ => add_atom(all,m,pi))
   1.313 +  | poly(all as Const("op *",_) $ t $ (Const("Numeral.number_of",_)$c), m, pi)=
   1.314 +      (poly(t,m*HOLogic.dest_binum c,pi)
   1.315 +       handle TERM _ => add_atom(all,m,pi))
   1.316 +  | poly(all as Const("Numeral.number_of",_)$t,m,(p,i)) =
   1.317 +     ((p,i + m*HOLogic.dest_binum t)
   1.318 +      handle TERM _ => add_atom(all,m,(p,i)))
   1.319 +  | poly x  = add_atom x;
   1.320 +
   1.321 +fun decomp2(rel,lhs,rhs) =
   1.322 +  let val (p,i) = poly(lhs,1,([],0)) and (q,j) = poly(rhs,1,([],0))
   1.323 +  in case rel of
   1.324 +       "op <"  => Some(p,i,"<",q,j)
   1.325 +     | "op <=" => Some(p,i,"<=",q,j)
   1.326 +     | "op ="  => Some(p,i,"=",q,j)
   1.327 +     | _       => None
   1.328 +  end;
   1.329 +
   1.330 +fun negate(Some(x,i,rel,y,j,d)) = Some(x,i,"~"^rel,y,j,d)
   1.331 +  | negate None = None;
   1.332 +
   1.333 +fun decomp1 discrete (T,xxx) =
   1.334 +  (case T of
   1.335 +     Type("fun",[Type(D,[]),_]) =>
   1.336 +       (case assoc(discrete,D) of
   1.337 +          None => None
   1.338 +        | Some d => (case decomp2 xxx of
   1.339 +                       None => None
   1.340 +                     | Some(p,i,rel,q,j) => Some(p,i,rel,q,j,d)))
   1.341 +   | _ => None);
   1.342 +
   1.343 +fun decomp2 discrete (_$(Const(rel,T)$lhs$rhs)) = decomp1 discrete (T,(rel,lhs,rhs))
   1.344 +  | decomp2 discrete (_$(Const("Not",_)$(Const(rel,T)$lhs$rhs))) =
   1.345 +      negate(decomp1 discrete (T,(rel,lhs,rhs)))
   1.346 +  | decomp2 discrete _ = None
   1.347 +
   1.348 +val decomp = decomp2 o #discrete o ArithData.get_sg;
   1.349 +
   1.350 +end;
   1.351 +
   1.352 +
   1.353 +structure Fast_Arith =
   1.354 +  Fast_Lin_Arith(structure LA_Logic=LA_Logic and LA_Data=LA_Data_Ref);
   1.355 +
   1.356 +val fast_arith_tac = Fast_Arith.lin_arith_tac
   1.357 +and trace_arith    = Fast_Arith.trace;
   1.358 +
   1.359 +local
   1.360 +
   1.361 +(* reduce contradictory <= to False.
   1.362 +   Most of the work is done by the cancel tactics.
   1.363 +*)
   1.364 +val add_rules = [add_0,add_0_right,Zero_not_Suc,Suc_not_Zero,le_0_eq];
   1.365 +
   1.366 +val add_mono_thms_nat = map (fn s => prove_goal (the_context ()) s
   1.367 + (fn prems => [cut_facts_tac prems 1,
   1.368 +               blast_tac (claset() addIs [add_le_mono]) 1]))
   1.369 +["(i <= j) & (k <= l) ==> i + k <= j + (l::nat)",
   1.370 + "(i  = j) & (k <= l) ==> i + k <= j + (l::nat)",
   1.371 + "(i <= j) & (k  = l) ==> i + k <= j + (l::nat)",
   1.372 + "(i  = j) & (k  = l) ==> i + k  = j + (l::nat)"
   1.373 +];
   1.374 +
   1.375 +in
   1.376 +
   1.377 +val init_lin_arith_data =
   1.378 + Fast_Arith.setup @
   1.379 + [Fast_Arith.map_data (fn {add_mono_thms, lessD, simpset = _} =>
   1.380 +   {add_mono_thms = add_mono_thms @ add_mono_thms_nat,
   1.381 +    lessD = lessD @ [Suc_leI],
   1.382 +    simpset = HOL_basic_ss addsimps add_rules addsimprocs nat_cancel_sums_add}),
   1.383 +  ArithData.init, arith_discrete ("nat", true)];
   1.384 +
   1.385 +end;
   1.386 +
   1.387 +
   1.388 +local
   1.389 +val nat_arith_simproc_pats =
   1.390 +  map (fn s => Thm.read_cterm (Theory.sign_of (the_context ())) (s, HOLogic.boolT))
   1.391 +      ["(m::nat) < n","(m::nat) <= n", "(m::nat) = n"];
   1.392 +in
   1.393 +val fast_nat_arith_simproc = mk_simproc
   1.394 +  "fast_nat_arith" nat_arith_simproc_pats Fast_Arith.lin_arith_prover;
   1.395 +end;
   1.396 +
   1.397 +(* Because of fast_nat_arith_simproc, the arithmetic solver is really only
   1.398 +useful to detect inconsistencies among the premises for subgoals which are
   1.399 +*not* themselves (in)equalities, because the latter activate
   1.400 +fast_nat_arith_simproc anyway. However, it seems cheaper to activate the
   1.401 +solver all the time rather than add the additional check. *)
   1.402 +
   1.403 +
   1.404 +(* arith proof method *)
   1.405 +
   1.406 +(* FIXME: K true should be replaced by a sensible test to speed things up
   1.407 +   in case there are lots of irrelevant terms involved;
   1.408 +   elimination of min/max can be optimized:
   1.409 +   (max m n + k <= r) = (m+k <= r & n+k <= r)
   1.410 +   (l <= min m n + k) = (l <= m+k & l <= n+k)
   1.411 +*)
   1.412 +fun arith_tac i st =
   1.413 +  refute_tac (K true) (REPEAT o split_tac (#splits (ArithData.get_sg (Thm.sign_of_thm st))))
   1.414 +             ((REPEAT_DETERM o etac linorder_neqE) THEN' fast_arith_tac) i st;
   1.415 +
   1.416 +fun arith_method prems =
   1.417 +  Method.METHOD (fn facts => HEADGOAL (Method.insert_tac (prems @ facts) THEN' arith_tac));
   1.418 +
   1.419 +
   1.420 +(* theory setup *)
   1.421 +
   1.422 +val arith_setup =
   1.423 + [Simplifier.change_simpset_of (op addsimprocs) nat_cancel] @
   1.424 +  init_lin_arith_data @
   1.425 +  [Simplifier.change_simpset_of (op addSolver)
   1.426 +   (mk_solver "lin. arith." Fast_Arith.cut_lin_arith_tac),
   1.427 +  Simplifier.change_simpset_of (op addsimprocs) [fast_nat_arith_simproc],
   1.428 +  Method.add_methods [("arith", (arith_method o #2) oo Method.syntax Args.bang_facts,
   1.429 +    "decide linear arithmethic")],
   1.430 +  Attrib.add_attributes [("arith_split",
   1.431 +    (Attrib.no_args arith_split_add, Attrib.no_args Attrib.undef_local_attribute),
   1.432 +    "declare split rules for arithmetic procedure")]];