src/HOL/arith_data.ML
author wenzelm
Tue Jan 16 00:28:50 2001 +0100 (2001-01-16)
changeset 10906 de95ba2760fe
parent 10766 ace2ba2d4fd1
child 11334 a16eaf2a1edd
permissions -rw-r--r--
tuned atomize;
     1 (*  Title:      HOL/arith_data.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel, Stefan Berghofer and Tobias Nipkow
     4 
     5 Various arithmetic proof procedures.
     6 *)
     7 
     8 (*---------------------------------------------------------------------------*)
     9 (* 1. Cancellation of common terms                                           *)
    10 (*---------------------------------------------------------------------------*)
    11 
    12 signature ARITH_DATA =
    13 sig
    14   val nat_cancel_sums_add: simproc list
    15   val nat_cancel_sums: simproc list
    16 end;
    17 
    18 structure ArithData: ARITH_DATA =
    19 struct
    20 
    21 
    22 (** abstract syntax of structure nat: 0, Suc, + **)
    23 
    24 (* mk_sum, mk_norm_sum *)
    25 
    26 val one = HOLogic.mk_nat 1;
    27 val mk_plus = HOLogic.mk_binop "op +";
    28 
    29 fun mk_sum [] = HOLogic.zero
    30   | mk_sum [t] = t
    31   | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);
    32 
    33 (*normal form of sums: Suc (... (Suc (a + (b + ...))))*)
    34 fun mk_norm_sum ts =
    35   let val (ones, sums) = partition (equal one) ts in
    36     funpow (length ones) HOLogic.mk_Suc (mk_sum sums)
    37   end;
    38 
    39 
    40 (* dest_sum *)
    41 
    42 val dest_plus = HOLogic.dest_bin "op +" HOLogic.natT;
    43 
    44 fun dest_sum tm =
    45   if HOLogic.is_zero tm then []
    46   else
    47     (case try HOLogic.dest_Suc tm of
    48       Some t => one :: dest_sum t
    49     | None =>
    50         (case try dest_plus tm of
    51           Some (t, u) => dest_sum t @ dest_sum u
    52         | None => [tm]));
    53 
    54 
    55 (** generic proof tools **)
    56 
    57 (* prove conversions *)
    58 
    59 val mk_eqv = HOLogic.mk_Trueprop o HOLogic.mk_eq;
    60 
    61 fun prove_conv expand_tac norm_tac sg (t, u) =
    62   mk_meta_eq (prove_goalw_cterm_nocheck [] (cterm_of sg (mk_eqv (t, u)))
    63     (K [expand_tac, norm_tac]))
    64   handle ERROR => error ("The error(s) above occurred while trying to prove " ^
    65     (string_of_cterm (cterm_of sg (mk_eqv (t, u)))));
    66 
    67 val subst_equals = prove_goal HOL.thy "[| t = s; u = t |] ==> u = s"
    68   (fn prems => [cut_facts_tac prems 1, SIMPSET' asm_simp_tac 1]);
    69 
    70 
    71 (* rewriting *)
    72 
    73 fun simp_all rules = ALLGOALS (simp_tac (HOL_ss addsimps rules));
    74 
    75 val add_rules = [add_Suc, add_Suc_right, add_0, add_0_right];
    76 val mult_rules = [mult_Suc, mult_Suc_right, mult_0, mult_0_right];
    77 
    78 
    79 
    80 (** cancel common summands **)
    81 
    82 structure Sum =
    83 struct
    84   val mk_sum = mk_norm_sum;
    85   val dest_sum = dest_sum;
    86   val prove_conv = prove_conv;
    87   val norm_tac = simp_all add_rules THEN simp_all add_ac;
    88 end;
    89 
    90 fun gen_uncancel_tac rule ct =
    91   rtac (instantiate' [] [None, Some ct] (rule RS subst_equals)) 1;
    92 
    93 
    94 (* nat eq *)
    95 
    96 structure EqCancelSums = CancelSumsFun
    97 (struct
    98   open Sum;
    99   val mk_bal = HOLogic.mk_eq;
   100   val dest_bal = HOLogic.dest_bin "op =" HOLogic.natT;
   101   val uncancel_tac = gen_uncancel_tac add_left_cancel;
   102 end);
   103 
   104 
   105 (* nat less *)
   106 
   107 structure LessCancelSums = CancelSumsFun
   108 (struct
   109   open Sum;
   110   val mk_bal = HOLogic.mk_binrel "op <";
   111   val dest_bal = HOLogic.dest_bin "op <" HOLogic.natT;
   112   val uncancel_tac = gen_uncancel_tac add_left_cancel_less;
   113 end);
   114 
   115 
   116 (* nat le *)
   117 
   118 structure LeCancelSums = CancelSumsFun
   119 (struct
   120   open Sum;
   121   val mk_bal = HOLogic.mk_binrel "op <=";
   122   val dest_bal = HOLogic.dest_bin "op <=" HOLogic.natT;
   123   val uncancel_tac = gen_uncancel_tac add_left_cancel_le;
   124 end);
   125 
   126 
   127 (* nat diff *)
   128 
   129 structure DiffCancelSums = CancelSumsFun
   130 (struct
   131   open Sum;
   132   val mk_bal = HOLogic.mk_binop "op -";
   133   val dest_bal = HOLogic.dest_bin "op -" HOLogic.natT;
   134   val uncancel_tac = gen_uncancel_tac diff_cancel;
   135 end);
   136 
   137 
   138 
   139 (** prepare nat_cancel simprocs **)
   140 
   141 fun prep_pat s = Thm.read_cterm (Theory.sign_of (the_context ())) 
   142                                 (s, HOLogic.termT);
   143 val prep_pats = map prep_pat;
   144 
   145 fun prep_simproc (name, pats, proc) = Simplifier.mk_simproc name pats proc;
   146 
   147 val eq_pats = prep_pats ["(l::nat) + m = n", "(l::nat) = m + n", "Suc m = n", 
   148                          "m = Suc n"];
   149 val less_pats = prep_pats ["(l::nat) + m < n", "(l::nat) < m + n", "Suc m < n",
   150                            "m < Suc n"];
   151 val le_pats = prep_pats ["(l::nat) + m <= n", "(l::nat) <= m + n", 
   152                          "Suc m <= n", "m <= Suc n"];
   153 val diff_pats = prep_pats ["((l::nat) + m) - n", "(l::nat) - (m + n)", 
   154                            "Suc m - n", "m - Suc n"];
   155 
   156 val nat_cancel_sums_add = map prep_simproc
   157   [("nateq_cancel_sums",   eq_pats,   EqCancelSums.proc),
   158    ("natless_cancel_sums", less_pats, LessCancelSums.proc),
   159    ("natle_cancel_sums",   le_pats,   LeCancelSums.proc)];
   160 
   161 val nat_cancel_sums = nat_cancel_sums_add @
   162   [prep_simproc("natdiff_cancel_sums", diff_pats, DiffCancelSums.proc)];
   163 
   164 
   165 end;
   166 
   167 open ArithData;
   168 
   169 
   170 (*---------------------------------------------------------------------------*)
   171 (* 2. Linear arithmetic                                                      *)
   172 (*---------------------------------------------------------------------------*)
   173 
   174 (* Parameters data for general linear arithmetic functor *)
   175 
   176 structure LA_Logic: LIN_ARITH_LOGIC =
   177 struct
   178 val ccontr = ccontr;
   179 val conjI = conjI;
   180 val neqE = linorder_neqE;
   181 val notI = notI;
   182 val sym = sym;
   183 val not_lessD = linorder_not_less RS iffD1;
   184 val not_leD = linorder_not_le RS iffD1;
   185 
   186 
   187 fun mk_Eq thm = (thm RS Eq_FalseI) handle THM _ => (thm RS Eq_TrueI);
   188 
   189 val mk_Trueprop = HOLogic.mk_Trueprop;
   190 
   191 fun neg_prop(TP$(Const("Not",_)$t)) = TP$t
   192   | neg_prop(TP$t) = TP $ (Const("Not",HOLogic.boolT-->HOLogic.boolT)$t);
   193 
   194 fun is_False thm =
   195   let val _ $ t = #prop(rep_thm thm)
   196   in t = Const("False",HOLogic.boolT) end;
   197 
   198 fun is_nat(t) = fastype_of1 t = HOLogic.natT;
   199 
   200 fun mk_nat_thm sg t =
   201   let val ct = cterm_of sg t  and cn = cterm_of sg (Var(("n",0),HOLogic.natT))
   202   in instantiate ([],[(cn,ct)]) le0 end;
   203 
   204 end;
   205 
   206 
   207 (* arith theory data *)
   208 
   209 structure ArithTheoryDataArgs =
   210 struct
   211   val name = "HOL/arith";
   212   type T = {splits: thm list, inj_consts: (string * typ)list, discrete: (string * bool) list};
   213 
   214   val empty = {splits = [], inj_consts = [], discrete = []};
   215   val copy = I;
   216   val prep_ext = I;
   217   fun merge ({splits= splits1, inj_consts= inj_consts1, discrete= discrete1},
   218              {splits= splits2, inj_consts= inj_consts2, discrete= discrete2}) =
   219    {splits = Drule.merge_rules (splits1, splits2),
   220     inj_consts = merge_lists inj_consts1 inj_consts2,
   221     discrete = merge_alists discrete1 discrete2};
   222   fun print _ _ = ();
   223 end;
   224 
   225 structure ArithTheoryData = TheoryDataFun(ArithTheoryDataArgs);
   226 
   227 fun arith_split_add (thy, thm) = (ArithTheoryData.map (fn {splits,inj_consts,discrete} =>
   228   {splits= thm::splits, inj_consts= inj_consts, discrete= discrete}) thy, thm);
   229 
   230 fun arith_discrete d = ArithTheoryData.map (fn {splits,inj_consts,discrete} =>
   231   {splits = splits, inj_consts = inj_consts, discrete = d :: discrete});
   232 
   233 fun arith_inj_const c = ArithTheoryData.map (fn {splits,inj_consts,discrete} =>
   234   {splits = splits, inj_consts = c :: inj_consts, discrete = discrete});
   235 
   236 
   237 structure LA_Data_Ref: LIN_ARITH_DATA =
   238 struct
   239 
   240 (* Decomposition of terms *)
   241 
   242 fun nT (Type("fun",[N,_])) = N = HOLogic.natT
   243   | nT _ = false;
   244 
   245 fun add_atom(t,m,(p,i)) = (case assoc(p,t) of None => ((t,m)::p,i)
   246                            | Some n => (overwrite(p,(t,ratadd(n,m))), i));
   247 
   248 exception Zero;
   249 
   250 fun rat_of_term(numt,dent) =
   251   let val num = HOLogic.dest_binum numt and den = HOLogic.dest_binum dent
   252   in if den = 0 then raise Zero else int_ratdiv(num,den) end;
   253 
   254 (* Warning: in rare cases number_of encloses a non-numeral,
   255    in which case dest_binum raises TERM; hence all the handles below.
   256 *)
   257 
   258 (* decompose nested multiplications, bracketing them to the right and combining all
   259    their coefficients
   260 *)
   261 
   262 fun demult((mC as Const("op *",_)) $ s $ t,m) = ((case s of
   263         Const("Numeral.number_of",_)$n
   264         => demult(t,ratmul(m,rat_of_int(HOLogic.dest_binum n)))
   265       | Const("op *",_) $ s1 $ s2 => demult(mC $ s1 $ (mC $ s2 $ t),m)
   266       | Const("HOL.divide",_) $ numt $ (Const("Numeral.number_of",_)$dent) =>
   267           let val den = HOLogic.dest_binum dent
   268           in if den = 0 then raise Zero
   269              else demult(mC $ numt $ t,ratmul(m, ratinv(rat_of_int den)))
   270           end
   271       | _ => atomult(mC,s,t,m)
   272       ) handle TERM _ => atomult(mC,s,t,m))
   273   | demult(atom as Const("HOL.divide",_) $ t $ (Const("Numeral.number_of",_)$dent), m) =
   274       (let val den = HOLogic.dest_binum dent
   275        in if den = 0 then raise Zero else demult(t,ratmul(m, ratinv(rat_of_int den))) end
   276        handle TERM _ => (Some atom,m))
   277   | demult(t as Const("Numeral.number_of",_)$n,m) =
   278       ((None,ratmul(m,rat_of_int(HOLogic.dest_binum n)))
   279        handle TERM _ => (Some t,m))
   280   | demult(atom,m) = (Some atom,m)
   281 
   282 and atomult(mC,atom,t,m) = (case demult(t,m) of (None,m') => (Some atom,m')
   283                             | (Some t',m') => (Some(mC $ atom $ t'),m'))
   284 
   285 fun decomp2 inj_consts (rel,lhs,rhs) =
   286 let
   287 (* Turn term into list of summand * multiplicity plus a constant *)
   288 fun poly(Const("op +",_) $ s $ t, m, pi) = poly(s,m,poly(t,m,pi))
   289   | poly(all as Const("op -",T) $ s $ t, m, pi) =
   290       if nT T then add_atom(all,m,pi)
   291       else poly(s,m,poly(t,ratneg m,pi))
   292   | poly(Const("uminus",_) $ t, m, pi) = poly(t,ratneg m,pi)
   293   | poly(Const("0",_), _, pi) = pi
   294   | poly(Const("Suc",_)$t, m, (p,i)) = poly(t, m, (p,ratadd(i,m)))
   295   | poly(t as Const("op *",_) $ _ $ _, m, pi as (p,i)) =
   296       (case demult(t,m) of
   297          (None,m') => (p,ratadd(i,m))
   298        | (Some u,m') => add_atom(u,m',pi))
   299   | poly(t as Const("HOL.divide",_) $ _ $ _, m, pi as (p,i)) =
   300       (case demult(t,m) of
   301          (None,m') => (p,ratadd(i,m))
   302        | (Some u,m') => add_atom(u,m',pi))
   303   | poly(all as (Const("Numeral.number_of",_)$t,m,(p,i))) =
   304       ((p,ratadd(i,ratmul(m,rat_of_int(HOLogic.dest_binum t))))
   305        handle TERM _ => add_atom all)
   306   | poly(all as Const f $ x, m, pi) =
   307       if f mem inj_consts then poly(x,m,pi) else add_atom(all,m,pi)
   308   | poly x  = add_atom x;
   309 
   310 val (p,i) = poly(lhs,rat_of_int 1,([],rat_of_int 0))
   311 and (q,j) = poly(rhs,rat_of_int 1,([],rat_of_int 0))
   312 
   313   in case rel of
   314        "op <"  => Some(p,i,"<",q,j)
   315      | "op <=" => Some(p,i,"<=",q,j)
   316      | "op ="  => Some(p,i,"=",q,j)
   317      | _       => None
   318   end handle Zero => None;
   319 
   320 fun negate(Some(x,i,rel,y,j,d)) = Some(x,i,"~"^rel,y,j,d)
   321   | negate None = None;
   322 
   323 fun decomp1 (discrete,inj_consts) (T,xxx) =
   324   (case T of
   325      Type("fun",[Type(D,[]),_]) =>
   326        (case assoc(discrete,D) of
   327           None => None
   328         | Some d => (case decomp2 inj_consts xxx of
   329                        None => None
   330                      | Some(p,i,rel,q,j) => Some(p,i,rel,q,j,d)))
   331    | _ => None);
   332 
   333 fun decomp2 data (_$(Const(rel,T)$lhs$rhs)) = decomp1 data (T,(rel,lhs,rhs))
   334   | decomp2 data (_$(Const("Not",_)$(Const(rel,T)$lhs$rhs))) =
   335       negate(decomp1 data (T,(rel,lhs,rhs)))
   336   | decomp2 data _ = None
   337 
   338 fun decomp sg =
   339   let val {discrete, inj_consts, ...} = ArithTheoryData.get_sg sg
   340   in decomp2 (discrete,inj_consts) end
   341 
   342 fun number_of(n,T) = HOLogic.number_of_const T $ (HOLogic.mk_bin n)
   343 
   344 end;
   345 
   346 
   347 structure Fast_Arith =
   348   Fast_Lin_Arith(structure LA_Logic=LA_Logic and LA_Data=LA_Data_Ref);
   349 
   350 val fast_arith_tac = Fast_Arith.lin_arith_tac
   351 and trace_arith    = Fast_Arith.trace;
   352 
   353 local
   354 
   355 (* reduce contradictory <= to False.
   356    Most of the work is done by the cancel tactics.
   357 *)
   358 val add_rules = [add_0,add_0_right,Zero_not_Suc,Suc_not_Zero,le_0_eq];
   359 
   360 val add_mono_thms_nat = map (fn s => prove_goal (the_context ()) s
   361  (fn prems => [cut_facts_tac prems 1,
   362                blast_tac (claset() addIs [add_le_mono]) 1]))
   363 ["(i <= j) & (k <= l) ==> i + k <= j + (l::nat)",
   364  "(i  = j) & (k <= l) ==> i + k <= j + (l::nat)",
   365  "(i <= j) & (k  = l) ==> i + k <= j + (l::nat)",
   366  "(i  = j) & (k  = l) ==> i + k  = j + (l::nat)"
   367 ];
   368 
   369 in
   370 
   371 val init_lin_arith_data =
   372  Fast_Arith.setup @
   373  [Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, simpset = _} =>
   374    {add_mono_thms = add_mono_thms @ add_mono_thms_nat,
   375     mult_mono_thms = mult_mono_thms,
   376     inj_thms = inj_thms,
   377     lessD = lessD @ [Suc_leI],
   378     simpset = HOL_basic_ss addsimps add_rules addsimprocs nat_cancel_sums_add}),
   379   ArithTheoryData.init, arith_discrete ("nat", true)];
   380 
   381 end;
   382 
   383 
   384 local
   385 val nat_arith_simproc_pats =
   386   map (fn s => Thm.read_cterm (Theory.sign_of (the_context ())) (s, HOLogic.boolT))
   387       ["(m::nat) < n","(m::nat) <= n", "(m::nat) = n"];
   388 in
   389 val fast_nat_arith_simproc = mk_simproc
   390   "fast_nat_arith" nat_arith_simproc_pats Fast_Arith.lin_arith_prover;
   391 end;
   392 
   393 (* Because of fast_nat_arith_simproc, the arithmetic solver is really only
   394 useful to detect inconsistencies among the premises for subgoals which are
   395 *not* themselves (in)equalities, because the latter activate
   396 fast_nat_arith_simproc anyway. However, it seems cheaper to activate the
   397 solver all the time rather than add the additional check. *)
   398 
   399 
   400 (* arith proof method *)
   401 
   402 (* FIXME: K true should be replaced by a sensible test to speed things up
   403    in case there are lots of irrelevant terms involved;
   404    elimination of min/max can be optimized:
   405    (max m n + k <= r) = (m+k <= r & n+k <= r)
   406    (l <= min m n + k) = (l <= m+k & l <= n+k)
   407 *)
   408 local
   409 
   410 fun raw_arith_tac i st =
   411   refute_tac (K true) (REPEAT o split_tac (#splits (ArithTheoryData.get_sg (Thm.sign_of_thm st))))
   412              ((REPEAT_DETERM o etac linorder_neqE) THEN' fast_arith_tac) i st;
   413 
   414 in
   415 
   416 val arith_tac = fast_arith_tac ORELSE' (atomize_tac THEN' raw_arith_tac);
   417 
   418 fun arith_method prems =
   419   Method.METHOD (fn facts => HEADGOAL (Method.insert_tac (prems @ facts) THEN' arith_tac));
   420 
   421 end;
   422 
   423 
   424 (* theory setup *)
   425 
   426 val arith_setup =
   427  [Simplifier.change_simpset_of (op addsimprocs) nat_cancel_sums] @
   428   init_lin_arith_data @
   429   [Simplifier.change_simpset_of (op addSolver)
   430    (mk_solver "lin. arith." Fast_Arith.cut_lin_arith_tac),
   431   Simplifier.change_simpset_of (op addsimprocs) [fast_nat_arith_simproc],
   432   Method.add_methods [("arith", (arith_method o #2) oo Method.syntax Args.bang_facts,
   433     "decide linear arithmethic")],
   434   Attrib.add_attributes [("arith_split",
   435     (Attrib.no_args arith_split_add, Attrib.no_args Attrib.undef_local_attribute),
   436     "declaration of split rules for arithmetic procedure")]];