src/HOL/arith_data.ML
author nipkow
Fri Dec 01 19:53:29 2000 +0100 (2000-12-01)
changeset 10574 8f98f0301d67
parent 10516 dc113303d101
child 10693 9e4a0e84d0d6
permissions -rw-r--r--
Linear arithmetic now copes with mixed nat/int formulae.
     1 (*  Title:      HOL/arith_data.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel, Stefan Berghofer and Tobias Nipkow
     4 
     5 Various arithmetic proof procedures.
     6 *)
     7 
     8 (*---------------------------------------------------------------------------*)
     9 (* 1. Cancellation of common terms                                           *)
    10 (*---------------------------------------------------------------------------*)
    11 
    12 signature ARITH_DATA =
    13 sig
    14   val nat_cancel_sums_add: simproc list
    15   val nat_cancel_sums: simproc list
    16   val nat_cancel_factor: simproc list
    17   val nat_cancel: simproc list
    18 end;
    19 
    20 structure ArithData: ARITH_DATA =
    21 struct
       (* Cancellation simprocs for nat.  The CancelSumsFun instances below
          remove common summands from both sides of =, <, <= and -; the
          CancelFactorFun instances handle a common factor in =, < and <=.
          Only the simproc lists declared in ARITH_DATA are exported. *)
    22 
    23 
    24 (** abstract syntax of structure nat: 0, Suc, + **)
    25 
    26 (* mk_sum, mk_norm_sum *)
    27 
       (* The nat numeral 1, used as the summand that stands for one Suc:
          dest_sum emits it for every Suc it strips, and mk_norm_sum turns
          every occurrence back into a Suc. *)
    28 val one = HOLogic.mk_nat 1;
    29 val mk_plus = HOLogic.mk_binop "op +";
    30 
       (* Right-nested sum of the given terms; the empty sum is 0. *)
    31 fun mk_sum [] = HOLogic.zero
    32   | mk_sum [t] = t
    33   | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    34 
    35 (*normal form of sums: Suc (... (Suc (a + (b + ...))))*)
    36 fun mk_norm_sum ts =
    37   let val (ones, sums) = partition (equal one) ts in
    38     funpow (length ones) HOLogic.mk_Suc (mk_sum sums)
    39   end;
    40 
    41 
    42 (* dest_sum *)
    43 
    44 val dest_plus = HOLogic.dest_bin "op +" HOLogic.natT;
    45 
       (* Flatten a nat term into its list of summands: 0 gives [], each Suc
          contributes one, + is split recursively, and any other term is
          kept as a single summand. *)
    46 fun dest_sum tm =
    47   if HOLogic.is_zero tm then []
    48   else
    49     (case try HOLogic.dest_Suc tm of
    50       Some t => one :: dest_sum t
    51     | None =>
    52         (case try dest_plus tm of
    53           Some (t, u) => dest_sum t @ dest_sum u
    54         | None => [tm]));
    55 
    56 
    57 (** generic proof tools **)
    58 
    59 (* prove conversions *)
    60 
    61 val mk_eqv = HOLogic.mk_Trueprop o HOLogic.mk_eq;
    62 
       (* Prove t = u (returned as a meta-equality) by running expand_tac and
          then norm_tac.  If the proof raises ERROR, fail with a message that
          prints the goal which could not be proved. *)
    63 fun prove_conv expand_tac norm_tac sg (t, u) =
    64   mk_meta_eq (prove_goalw_cterm_nocheck [] (cterm_of sg (mk_eqv (t, u)))
    65     (K [expand_tac, norm_tac]))
    66   handle ERROR => error ("The error(s) above occurred while trying to prove " ^
    67     (string_of_cterm (cterm_of sg (mk_eqv (t, u)))));
    68 
       (* [| t = s; u = t |] ==> u = s, proved once in HOL.thy; it is combined
          with a cancellation lemma via RS in gen_uncancel_tac and
          gen_multiply_tac. *)
    69 val subst_equals = prove_goal HOL.thy "[| t = s; u = t |] ==> u = s"
    70   (fn prems => [cut_facts_tac prems 1, SIMPSET' asm_simp_tac 1]);
    71 
    72 
    73 (* rewriting *)
    74 
       (* Simplify every subgoal using HOL_ss extended with the given rules. *)
    75 fun simp_all rules = ALLGOALS (simp_tac (HOL_ss addsimps rules));
    76 
       (* Suc/0 normalisation rules for + (add_rules) and * (mult_rules). *)
    77 val add_rules = [add_Suc, add_Suc_right, add_0, add_0_right];
    78 val mult_rules = [mult_Suc, mult_Suc_right, mult_0, mult_0_right];
    79 
    80 
    81 
    82 (** cancel common summands **)
    83 
       (* Parameters shared by all CancelSumsFun instances below; norm_tac
          simplifies with add_rules and then rewrites with add_ac. *)
    84 structure Sum =
    85 struct
    86   val mk_sum = mk_norm_sum;
    87   val dest_sum = dest_sum;
    88   val prove_conv = prove_conv;
    89   val norm_tac = simp_all add_rules THEN simp_all add_ac;
    90 end;
    91 
       (* Resolve subgoal 1 with rule RS subst_equals, instantiating its second
          schematic variable (in instantiate' order) with ct. *)
    92 fun gen_uncancel_tac rule ct =
    93   rtac (instantiate' [] [None, Some ct] (rule RS subst_equals)) 1;
    94 
    95 
       (* Each instance below supplies the relation's constructor (mk_bal),
          its destructor (dest_bal) and the lemma used by uncancel_tac. *)
    96 (* nat eq *)
    97 
    98 structure EqCancelSums = CancelSumsFun
    99 (struct
   100   open Sum;
   101   val mk_bal = HOLogic.mk_eq;
   102   val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT;
   103   val uncancel_tac = gen_uncancel_tac add_left_cancel;
   104 end);
   105 
   106 
   107 (* nat less *)
   108 
   109 structure LessCancelSums = CancelSumsFun
   110 (struct
   111   open Sum;
   112   val mk_bal = HOLogic.mk_binrel "op <";
   113   val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT;
   114   val uncancel_tac = gen_uncancel_tac add_left_cancel_less;
   115 end);
   116 
   117 
   118 (* nat le *)
   119 
   120 structure LeCancelSums = CancelSumsFun
   121 (struct
   122   open Sum;
   123   val mk_bal = HOLogic.mk_binrel "op <=";
   124   val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT;
   125   val uncancel_tac = gen_uncancel_tac add_left_cancel_le;
   126 end);
   127 
   128 
   129 (* nat diff *)
   130 
   131 structure DiffCancelSums = CancelSumsFun
   132 (struct
   133   open Sum;
   134   val mk_bal = HOLogic.mk_binop "op -";
   135   val dest_bal = HOLogic.dest_bin "op -" HOLogic.natT;
   136   val uncancel_tac = gen_uncancel_tac diff_cancel;
   137 end);
   138 
   139 
   140 
   141 (** cancel common factor **)
   142 
       (* Parameters shared by the CancelFactorFun instances below; unlike Sum,
          norm_tac also simplifies with mult_rules. *)
   143 structure Factor =
   144 struct
   145   val mk_sum = mk_norm_sum;
   146   val dest_sum = dest_sum;
   147   val prove_conv = prove_conv;
   148   val norm_tac = simp_all (add_rules @ mult_rules) THEN simp_all add_ac;
   149 end;
   150 
       (* Certify the nat numeral n; the theory context is read via
          the_context () each time this is called. *)
   151 fun mk_cnat n = cterm_of (Theory.sign_of (the_context ())) (HOLogic.mk_nat n);
   152 
       (* For k > 0, resolve subgoal 1 with rule RS subst_equals, instantiating
          its second schematic variable (in instantiate' order) with the
          numeral k - 1; for k <= 0 the tactic fails (no_tac). *)
   153 fun gen_multiply_tac rule k =
   154   if k > 0 then
   155     rtac (instantiate' [] [None, Some (mk_cnat (k - 1))] (rule RS subst_equals)) 1
   156   else no_tac;
   157 
   158 
       (* Each instance below supplies mk_bal, dest_bal and the lemma used by
          multiply_tac for its relation. *)
   159 (* nat eq *)
   160 
   161 structure EqCancelFactor = CancelFactorFun
   162 (struct
   163   open Factor;
   164   val mk_bal = HOLogic.mk_eq;
   165   val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT;
   166   val multiply_tac = gen_multiply_tac Suc_mult_cancel1;
   167 end);
   168 
   169 
   170 (* nat less *)
   171 
   172 structure LessCancelFactor = CancelFactorFun
   173 (struct
   174   open Factor;
   175   val mk_bal = HOLogic.mk_binrel "op <";
   176   val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT;
   177   val multiply_tac = gen_multiply_tac Suc_mult_less_cancel1;
   178 end);
   179 
   180 
   181 (* nat le *)
   182 
   183 structure LeCancelFactor = CancelFactorFun
   184 (struct
   185   open Factor;
   186   val mk_bal = HOLogic.mk_binrel "op <=";
   187   val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT;
   188   val multiply_tac = gen_multiply_tac Suc_mult_le_cancel1;
   189 end);
   190 
   191 
   192 
   193 (** prepare nat_cancel simprocs **)
   194 
       (* Read a simproc pattern (at type HOLogic.termT) in the current theory. *)
   195 fun prep_pat s = Thm.read_cterm (Theory.sign_of (the_context ())) (s, HOLogic.termT);
   196 val prep_pats = map prep_pat;
   197 
   198 fun prep_simproc (name, pats, proc) = Simplifier.mk_simproc name pats proc;
   199 
       (* Trigger patterns: a sum or a Suc on either side of the relation. *)
   200 val eq_pats = prep_pats ["(l::nat) + m = n", "(l::nat) = m + n", "Suc m = n", "m = Suc n"];
   201 val less_pats = prep_pats ["(l::nat) + m < n", "(l::nat) < m + n", "Suc m < n", "m < Suc n"];
   202 val le_pats = prep_pats ["(l::nat) + m <= n", "(l::nat) <= m + n", "Suc m <= n", "m <= Suc n"];
   203 val diff_pats = prep_pats ["((l::nat) + m) - n", "(l::nat) - (m + n)", "Suc m - n", "m - Suc n"];
   204 
       (* Sum cancellation for =, <, <=; nat_cancel_sums adds the - instance. *)
   205 val nat_cancel_sums_add = map prep_simproc
   206   [("nateq_cancel_sums", eq_pats, EqCancelSums.proc),
   207    ("natless_cancel_sums", less_pats, LessCancelSums.proc),
   208    ("natle_cancel_sums", le_pats, LeCancelSums.proc)];
   209 
   210 val nat_cancel_sums = nat_cancel_sums_add @
   211   [prep_simproc("natdiff_cancel_sums", diff_pats, DiffCancelSums.proc)];
   212 
   213 val nat_cancel_factor = map prep_simproc
   214   [("nateq_cancel_factor", eq_pats, EqCancelFactor.proc),
   215    ("natless_cancel_factor", less_pats, LessCancelFactor.proc),
   216    ("natle_cancel_factor", le_pats, LeCancelFactor.proc)];
   217 
       (* All nat cancellation simprocs: the factor simprocs first, then sums. *)
   218 val nat_cancel = nat_cancel_factor @ nat_cancel_sums;
   219 
   220 
   221 end;
   222 
       (* Bring the exported simproc lists into the top-level namespace. *)
   223 open ArithData;
   224 
   225 
   226 (*---------------------------------------------------------------------------*)
   227 (* 2. Linear arithmetic                                                      *)
   228 (*---------------------------------------------------------------------------*)
   229 
   230 (* Parameters data for general linear arithmetic functor *)
   231 
   232 structure LA_Logic: LIN_ARITH_LOGIC =
   233 struct
   234 val ccontr = ccontr;
   235 val conjI = conjI;
   236 val neqE = linorder_neqE;
   237 val notI = notI;
   238 val sym = sym;
   239 val not_lessD = linorder_not_less RS iffD1;
   240 val not_leD = linorder_not_le RS iffD1;
   241 
   242 
   243 fun mk_Eq thm = (thm RS Eq_FalseI) handle THM _ => (thm RS Eq_TrueI);
   244 
   245 val mk_Trueprop = HOLogic.mk_Trueprop;
   246 
   247 fun neg_prop(TP$(Const("Not",_)$t)) = TP$t
   248   | neg_prop(TP$t) = TP $ (Const("Not",HOLogic.boolT-->HOLogic.boolT)$t);
   249 
   250 fun is_False thm =
   251   let val _ $ t = #prop(rep_thm thm)
   252   in t = Const("False",HOLogic.boolT) end;
   253 
   254 fun is_nat(t) = fastype_of1 t = HOLogic.natT;
   255 
   256 fun mk_nat_thm sg t =
   257   let val ct = cterm_of sg t  and cn = cterm_of sg (Var(("n",0),HOLogic.natT))
   258   in instantiate ([],[(cn,ct)]) le0 end;
   259 
   260 end;
   261 
   262 
   263 (* arith theory data *)
   264 
   265 structure ArithTheoryDataArgs =
   266 struct
   267   val name = "HOL/arith";
   268   type T = {splits: thm list, inj_consts: (string * typ)list, discrete: (string * bool) list};
   269 
   270   val empty = {splits = [], inj_consts = [], discrete = []};
   271   val copy = I;
   272   val prep_ext = I;
   273   fun merge ({splits= splits1, inj_consts= inj_consts1, discrete= discrete1},
   274              {splits= splits2, inj_consts= inj_consts2, discrete= discrete2}) =
   275    {splits = Drule.merge_rules (splits1, splits2),
   276     inj_consts = merge_lists inj_consts1 inj_consts2,
   277     discrete = merge_alists discrete1 discrete2};
   278   fun print _ _ = ();
   279 end;
   280 
   281 structure ArithTheoryData = TheoryDataFun(ArithTheoryDataArgs);
   282 
   283 fun arith_split_add (thy, thm) = (ArithTheoryData.map (fn {splits,inj_consts,discrete} =>
   284   {splits= thm::splits, inj_consts= inj_consts, discrete= discrete}) thy, thm);
   285 
   286 fun arith_discrete d = ArithTheoryData.map (fn {splits,inj_consts,discrete} =>
   287   {splits = splits, inj_consts = inj_consts, discrete = d :: discrete});
   288 
   289 fun arith_inj_const c = ArithTheoryData.map (fn {splits,inj_consts,discrete} =>
   290   {splits = splits, inj_consts = c :: inj_consts, discrete = discrete});
   291 
   292 
   293 structure LA_Data_Ref: LIN_ARITH_DATA =
   294 struct
   295 
   296 (* Decomposition of terms *)
   297 
   298 fun nT (Type("fun",[N,_])) = N = HOLogic.natT
   299   | nT _ = false;
   300 
   301 fun add_atom(t,m,(p,i)) = (case assoc(p,t) of None => ((t,m)::p,i)
   302                            | Some n => (overwrite(p,(t,n+m:int)), i));
   303 
   304 fun decomp2 inj_consts (rel,lhs,rhs) =
   305 
   306 let
   307 (* Turn term into list of summand * multiplicity plus a constant *)
   308 fun poly(Const("op +",_) $ s $ t, m, pi) = poly(s,m,poly(t,m,pi))
   309   | poly(all as Const("op -",T) $ s $ t, m, pi) =
   310       if nT T then add_atom(all,m,pi)
   311       else poly(s,m,poly(t,~1*m,pi))
   312   | poly(Const("uminus",_) $ t, m, pi) = poly(t,~1*m,pi)
   313   | poly(Const("0",_), _, pi) = pi
   314   | poly(Const("Suc",_)$t, m, (p,i)) = poly(t, m, (p,i+m))
   315   | poly(all as Const("op *",_) $ (Const("Numeral.number_of",_)$c) $ t, m, pi)=
   316       (poly(t,m*HOLogic.dest_binum c,pi)
   317        handle TERM _ => add_atom(all,m,pi))
   318   | poly(all as Const("op *",_) $ t $ (Const("Numeral.number_of",_)$c), m, pi)=
   319       (poly(t,m*HOLogic.dest_binum c,pi)
   320        handle TERM _ => add_atom(all,m,pi))
   321   | poly(all as Const("Numeral.number_of",_)$t,m,(p,i)) =
   322      ((p,i + m*HOLogic.dest_binum t)
   323       handle TERM _ => add_atom(all,m,(p,i)))
   324   | poly(all as Const f $ x, m, pi) =
   325       if f mem inj_consts then poly(x,m,pi) else add_atom(all,m,pi)
   326   | poly x  = add_atom x;
   327 
   328   val (p,i) = poly(lhs,1,([],0)) and (q,j) = poly(rhs,1,([],0))
   329   in case rel of
   330        "op <"  => Some(p,i,"<",q,j)
   331      | "op <=" => Some(p,i,"<=",q,j)
   332      | "op ="  => Some(p,i,"=",q,j)
   333      | _       => None
   334   end;
   335 
   336 fun negate(Some(x,i,rel,y,j,d)) = Some(x,i,"~"^rel,y,j,d)
   337   | negate None = None;
   338 
   339 fun decomp1 (discrete,inj_consts) (T,xxx) =
   340   (case T of
   341      Type("fun",[Type(D,[]),_]) =>
   342        (case assoc(discrete,D) of
   343           None => None
   344         | Some d => (case decomp2 inj_consts xxx of
   345                        None => None
   346                      | Some(p,i,rel,q,j) => Some(p,i,rel,q,j,d)))
   347    | _ => None);
   348 
   349 fun decomp2 data (_$(Const(rel,T)$lhs$rhs)) = decomp1 data (T,(rel,lhs,rhs))
   350   | decomp2 data (_$(Const("Not",_)$(Const(rel,T)$lhs$rhs))) =
   351       negate(decomp1 data (T,(rel,lhs,rhs)))
   352   | decomp2 data _ = None
   353 
   354 fun decomp sg =
   355   let val {discrete, inj_consts, ...} = ArithTheoryData.get_sg sg
   356   in decomp2 (discrete,inj_consts) end
   357 
   358 end;
   359 
   360 
   361 structure Fast_Arith =
   362   Fast_Lin_Arith(structure LA_Logic=LA_Logic and LA_Data=LA_Data_Ref);
   363 
   364 val fast_arith_tac = Fast_Arith.lin_arith_tac
   365 and trace_arith    = Fast_Arith.trace;
   366 
   367 local
   368 
   369 (* reduce contradictory <= to False.
   370    Most of the work is done by the cancel tactics.
   371 *)
   372 val add_rules = [add_0,add_0_right,Zero_not_Suc,Suc_not_Zero,le_0_eq];
   373 
   374 val add_mono_thms_nat = map (fn s => prove_goal (the_context ()) s
   375  (fn prems => [cut_facts_tac prems 1,
   376                blast_tac (claset() addIs [add_le_mono]) 1]))
   377 ["(i <= j) & (k <= l) ==> i + k <= j + (l::nat)",
   378  "(i  = j) & (k <= l) ==> i + k <= j + (l::nat)",
   379  "(i <= j) & (k  = l) ==> i + k <= j + (l::nat)",
   380  "(i  = j) & (k  = l) ==> i + k  = j + (l::nat)"
   381 ];
   382 
   383 in
   384 
   385 val init_lin_arith_data =
   386  Fast_Arith.setup @
   387  [Fast_Arith.map_data (fn {add_mono_thms, inj_thms, lessD, simpset = _} =>
   388    {add_mono_thms = add_mono_thms @ add_mono_thms_nat,
   389     inj_thms = inj_thms,
   390     lessD = lessD @ [Suc_leI],
   391     simpset = HOL_basic_ss addsimps add_rules addsimprocs nat_cancel_sums_add}),
   392   ArithTheoryData.init, arith_discrete ("nat", true)];
   393 
   394 end;
   395 
   396 
   397 local
   398 val nat_arith_simproc_pats =
   399   map (fn s => Thm.read_cterm (Theory.sign_of (the_context ())) (s, HOLogic.boolT))
   400       ["(m::nat) < n","(m::nat) <= n", "(m::nat) = n"];
   401 in
   402 val fast_nat_arith_simproc = mk_simproc
   403   "fast_nat_arith" nat_arith_simproc_pats Fast_Arith.lin_arith_prover;
   404 end;
   405 
   406 (* Because of fast_nat_arith_simproc, the arithmetic solver is really only
   407 useful to detect inconsistencies among the premises for subgoals which are
   408 *not* themselves (in)equalities, because the latter activate
   409 fast_nat_arith_simproc anyway. However, it seems cheaper to activate the
   410 solver all the time rather than add the additional check. *)
   411 
   412 
   413 (* arith proof method *)
   414 
   415 (* FIXME: K true should be replaced by a sensible test to speed things up
   416    in case there are lots of irrelevant terms involved;
   417    elimination of min/max can be optimized:
   418    (max m n + k <= r) = (m+k <= r & n+k <= r)
   419    (l <= min m n + k) = (l <= m+k & l <= n+k)
   420 *)
   421 local
   422 
   423 val atomize_tac = Method.atomize_tac (thms "atomize'");
   424 
   425 fun raw_arith_tac i st =
   426   refute_tac (K true) (REPEAT o split_tac (#splits (ArithTheoryData.get_sg (Thm.sign_of_thm st))))
   427              ((REPEAT_DETERM o etac linorder_neqE) THEN' fast_arith_tac) i st;
   428 
   429 in
   430 
   431 val arith_tac = fast_arith_tac ORELSE' (atomize_tac THEN' raw_arith_tac);
   432 
   433 fun arith_method prems =
   434   Method.METHOD (fn facts => HEADGOAL (Method.insert_tac (prems @ facts) THEN' arith_tac));
   435 
   436 end;
   437 
   438 
   439 (* theory setup *)
   440 
   441 val arith_setup =
   442  [Simplifier.change_simpset_of (op addsimprocs) nat_cancel] @
   443   init_lin_arith_data @
   444   [Simplifier.change_simpset_of (op addSolver)
   445    (mk_solver "lin. arith." Fast_Arith.cut_lin_arith_tac),
   446   Simplifier.change_simpset_of (op addsimprocs) [fast_nat_arith_simproc],
   447   Method.add_methods [("arith", (arith_method o #2) oo Method.syntax Args.bang_facts,
   448     "decide linear arithmethic")],
   449   Attrib.add_attributes [("arith_split",
   450     (Attrib.no_args arith_split_add, Attrib.no_args Attrib.undef_local_attribute),
   451     "declaration of split rules for arithmetic procedure")]];