|
1 (* Title: ZF/arith_data.ML |
|
2 ID: $Id$ |
|
3 Author: Lawrence C Paulson, Cambridge University Computer Laboratory |
|
4 Copyright 2000 University of Cambridge |
|
5 |
|
6 Arithmetic simplification: cancellation of common terms |
|
7 *) |
|
8 |
|
(*Public interface of ArithData: only the list of cancellation simprocs is
  exported; every helper in the structure stays hidden by the opaque
  ascription.*)
signature ARITH_DATA =
sig
  val nat_cancel : simproc list
end;
|
13 |
|
structure ArithData: ARITH_DATA =
struct

val iT = Ind_Syntax.iT;

val zero = Const("0", iT);
val succ = Const("succ", iT --> iT);
fun mk_succ t = succ $ t;
val one = mk_succ zero;

(*Not FOLogic.mk_binop, since it calls fastype_of, which can fail*)
fun mk_binop_i c (t,u) = Const (c, [iT,iT] ---> iT) $ t $ u;
fun mk_binrel_i c (t,u) = Const (c, [iT,iT] ---> oT) $ t $ u;

val mk_plus = mk_binop_i "Arith.add";

(*Right-nested sum of a list of terms.
  mk_sum [] = 0 and mk_sum [t] = t #+ 0; longer lists end with their last
  element instead of a trailing zero, e.g. mk_sum [a,b,c] = a #+ (b #+ c).*)
fun mk_sum [] = zero
  | mk_sum [t,u] = mk_plus (t, u)
  | mk_sum (t :: ts) = mk_plus (t, mk_sum ts);

(*Right-nested sum that ALWAYS ends with a trailing zero:
  long_mk_sum [a,b,c] = a #+ (b #+ (c #+ 0)).
  The tail must recurse through long_mk_sum itself; handing it to mk_sum
  (as before) lost the zero for lists of three or more elements.*)
fun long_mk_sum [] = zero
  | long_mk_sum (t :: ts) = mk_plus (t, long_mk_sum ts);

val dest_plus = FOLogic.dest_bin "Arith.add" iT;

(*Flatten a term into its list of summands: 0 contributes nothing,
  succ(t) contributes the constant `one` followed by the summands of t,
  t #+ u contributes the summands of both sides, and any other term is a
  single summand.*)
fun dest_sum (Const("0",_)) = []
  | dest_sum (Const("succ",_) $ t) = one :: dest_sum t
  | dest_sum (Const("Arith.add",_) $ t $ u) = dest_sum t @ dest_sum u
  | dest_sum tm = [tm];

(*Apply the given rewrite (if present) just once: None gives the no-op
  tactic, Some th resolves every goal with th RS th2.*)
fun gen_trans_tac _ None = all_tac
  | gen_trans_tac th2 (Some th) = ALLGOALS (rtac (th RS th2));

(*Use = for terms of type i and <-> for anything else (i.e. formulas)*)
fun mk_eq_iff(t,u) =
  if fastype_of t = iT then FOLogic.mk_eq(t,u)
  else FOLogic.mk_iff(t,u);

(*Add the propositions of the theorems chyps as premises of ct*)
fun add_chyps chyps ct = Drule.list_implies (map cprop_of chyps, ct);

(*Try to prove  hyps ==> t = u  (or t <-> u) with the tactics tacs, after
  cutting the hypotheses in as facts, and discharge the hyps with MRS.
  Returns None when t and u are already alpha-equivalent, or when the
  proof attempt raises ERROR; in the latter case a warning naming the
  simproc `name` is printed.*)
fun prove_conv name tacs sg hyps (t,u) =
  if t aconv u then None
  else
  let val ct = add_chyps hyps
                 (cterm_of sg (FOLogic.mk_Trueprop (mk_eq_iff(t, u))))
  in Some
      (hyps MRS
       (prove_goalw_cterm_nocheck [] ct
        (fn prems => cut_facts_tac prems 1 :: tacs)))
      handle ERROR =>
        (warning
          ("Cancellation failed: no typing information? (" ^ name ^ ")");
         None)
  end;

(*Build a simproc from a (name, patterns, procedure) triple*)
fun prep_simproc (name, pats, proc) = Simplifier.mk_simproc name pats proc;
(*Parse a pattern string in the current theory context*)
fun prep_pat s = Thm.read_cterm (Theory.sign_of (the_context ()))
                               (s, TypeInfer.anyT ["logic"]);
val prep_pats = map prep_pat;


(*** Use CancelNumerals simproc without binary numerals,
     just for cancellation ***)

val mk_times = mk_binop_i "Arith.mult";

(*Right-nested product; factors equal to `one` are dropped, and the empty
  product is `one`.*)
fun mk_prod [] = one
  | mk_prod [t] = t
  | mk_prod (t :: ts) = if t = one then mk_prod ts
                        else mk_times (t, mk_prod ts);

val dest_times = FOLogic.dest_bin "Arith.mult" iT;

(*Flatten nested #* into its list of factors; a non-product is a single
  factor.*)
fun dest_prod t =
  let val (t,u) = dest_times t
  in dest_prod t @ dest_prod u end
  handle TERM _ => [t];

(*Dummy version: the only arguments are 0 and 1*)
fun mk_coeff (0, _) = zero
  | mk_coeff (1, t) = t
  | mk_coeff _ = raise TERM("mk_coeff", []);

(*Dummy version: the "coefficient" is always 1.
  In the result, the factors are sorted terms*)
fun dest_coeff t = (1, mk_prod (sort Term.term_ord (dest_prod t)));

(*Find the first coefficient-term whose term part is alpha-equivalent to u.
  Returns its coefficient together with the remaining terms (original order
  kept); raises TERM if none matches.  Only dest_coeff is guarded by the
  handler: a TERM escaping from the recursive call must propagate, otherwise
  every level repeats the entire tail search (exponential work on failure).*)
fun find_first_coeff past u [] = raise TERM("find_first_coeff", [])
  | find_first_coeff past u (t::terms) =
      let val coeff = Some (dest_coeff t) handle TERM _ => None
      in
        case coeff of
            Some (n, u') =>
              if u aconv u' then (n, rev past @ terms)
              else find_first_coeff (t::past) u terms
          | None => find_first_coeff (t::past) u terms
      end;


(*Rule lists for the normalisation tactics below, grouped by the operator
  or concern they address (0 and succ under #+, 1 under #*, typing rules,
  natify elimination)*)
val add_0s = [add_0_natify, add_0_right_natify];
val add_succs = [add_succ, add_succ_right];
val mult_1s = [mult_1_natify, mult_1_right_natify];
val tc_rules = [natify_in_nat, add_type, diff_type, mult_type];
(*NOTE(review): this list previously named add_natify1 and add_natify2
  twice; the duplicates are dropped.  mult_natify1/mult_natify2 may have
  been intended in their place - confirm before adding them.*)
val natifys = [natify_0, natify_ident, add_natify1, add_natify2,
               diff_natify1, diff_natify2];

(*Final simplification of the generated meta-equation with the given
  rules, using FOL_ss with the propositional iff_simps removed*)
fun simplify_meta_eq rules =
    mk_meta_eq o
    simplify (FOL_ss addeqcongs[eq_cong2,iff_cong2]
                     delsimps iff_simps (*these could erase the whole rule!*)
                     addsimps rules)

val final_rules = add_0s @ mult_1s @ [mult_0, mult_0_right];

(*Parameters shared by all three instances of CancelNumeralsFun*)
structure CancelNumeralsCommon =
  struct
  val mk_sum            = mk_sum
  val dest_sum          = dest_sum
  val mk_coeff          = mk_coeff
  val dest_coeff        = dest_coeff
  val find_first_coeff  = find_first_coeff []
  val norm_tac_ss1 = ZF_ss addsimps add_0s@add_succs@mult_1s@add_ac
  val norm_tac_ss2 = ZF_ss addsimps add_ac@mult_ac@tc_rules@natifys
  val norm_tac = ALLGOALS (asm_simp_tac norm_tac_ss1)
                 THEN ALLGOALS (asm_simp_tac norm_tac_ss2)
  val numeral_simp_tac_ss = ZF_ss addsimps add_0s@tc_rules@natifys
  val numeral_simp_tac  = ALLGOALS (asm_simp_tac numeral_simp_tac_ss)
  val simplify_meta_eq  = simplify_meta_eq final_rules
  end;


(*Cancellation on both sides of = *)
structure EqCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = prove_conv "nateq_cancel_numerals"
  val mk_bal   = FOLogic.mk_eq
  val dest_bal = FOLogic.dest_bin "op =" iT
  val bal_add1 = eq_add_iff RS iff_trans
  val bal_add2 = eq_add_iff RS iff_trans
  val trans_tac = gen_trans_tac iff_trans
);

(*Cancellation on both sides of < *)
structure LessCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = prove_conv "natless_cancel_numerals"
  val mk_bal   = mk_binrel_i "Ordinal.op <"
  val dest_bal = FOLogic.dest_bin "Ordinal.op <" iT
  val bal_add1 = less_add_iff RS iff_trans
  val bal_add2 = less_add_iff RS iff_trans
  val trans_tac = gen_trans_tac iff_trans
);

(*Cancellation on both sides of #- *)
structure DiffCancelNumerals = CancelNumeralsFun
 (open CancelNumeralsCommon
  val prove_conv = prove_conv "natdiff_cancel_numerals"
  val mk_bal   = mk_binop_i "Arith.diff"
  val dest_bal = FOLogic.dest_bin "Arith.diff" iT
  val bal_add1 = diff_add_eq RS trans
  val bal_add2 = diff_add_eq RS trans
  val trans_tac = gen_trans_tac trans
);


(*The exported simprocs, each triggered by sums, products or succ on
  either side of its relation/operator*)
val nat_cancel =
  map prep_simproc
   [("nateq_cancel_numerals",
     prep_pats ["l #+ m = n", "l = m #+ n",
                "l #* m = n", "l = m #* n",
                "succ(m) = n", "m = succ(n)"],
     EqCancelNumerals.proc),
    ("natless_cancel_numerals",
     prep_pats ["l #+ m < n", "l < m #+ n",
                "l #* m < n", "l < m #* n",
                "succ(m) < n", "m < succ(n)"],
     LessCancelNumerals.proc),
    ("natdiff_cancel_numerals",
     prep_pats ["(l #+ m) #- n", "l #- (m #+ n)",
                "(l #* m) #- n", "l #- (m #* n)",
                "succ(m) #- n", "m #- succ(n)"],
     DiffCancelNumerals.proc)];

end;
|
201 |
|
202 (*examples: |
|
203 print_depth 22; |
|
204 set timing; |
|
205 set trace_simp; |
|
206 fun test s = (Goal s; by (Asm_simp_tac 1)); |
|
207 |
|
208 test "x #+ y = x #+ z"; |
|
209 test "y #+ x = x #+ z"; |
|
210 test "x #+ y #+ z = x #+ z"; |
|
211 test "y #+ (z #+ x) = z #+ x"; |
|
212 test "x #+ y #+ z = (z #+ y) #+ (x #+ w)"; |
|
213 test "x#*y #+ z = (z #+ y) #+ (y#*x #+ w)"; |
|
214 |
|
215 test "x #+ succ(y) = x #+ z"; |
|
216 test "x #+ succ(y) = succ(z #+ x)"; |
|
217 test "succ(x) #+ succ(y) #+ z = succ(z #+ y) #+ succ(x #+ w)"; |
|
218 |
|
219 test "(x #+ y) #- (x #+ z) = w"; |
|
220 test "(y #+ x) #- (x #+ z) = dd"; |
|
221 test "(x #+ y #+ z) #- (x #+ z) = dd"; |
|
222 test "(y #+ (z #+ x)) #- (z #+ x) = dd"; |
|
223 test "(x #+ y #+ z) #- ((z #+ y) #+ (x #+ w)) = dd"; |
|
224 test "(x#*y #+ z) #- ((z #+ y) #+ (y#*x #+ w)) = dd"; |
|
225 |
|
226 (*BAD occurrence of natify*) |
|
227 test "(x #+ succ(y)) #- (x #+ z) = dd"; |
|
228 |
|
229 test "x #* y2 #+ y #* x2 = y #* x2 #+ x #* y2"; |
|
230 |
|
231 test "(x #+ succ(y)) #- (succ(z #+ x)) = dd"; |
|
232 test "(succ(x) #+ succ(y) #+ z) #- (succ(z #+ y) #+ succ(x #+ w)) = dd"; |
|
233 |
|
234 (*use of typing information*) |
|
235 test "x : nat ==> x #+ y = x"; |
|
236 test "x : nat --> x #+ y = x"; |
|
237 test "x : nat ==> x #+ y < x"; |
|
238 test "x : nat ==> x < y#+x"; |
|
239 |
|
240 (*fails: no typing information isn't visible*) |
|
241 test "x #+ y = x"; |
|
242 |
|
243 test "x #+ y < x #+ z"; |
|
244 test "y #+ x < x #+ z"; |
|
245 test "x #+ y #+ z < x #+ z"; |
|
246 test "y #+ z #+ x < x #+ z"; |
|
247 test "y #+ (z #+ x) < z #+ x"; |
|
248 test "x #+ y #+ z < (z #+ y) #+ (x #+ w)"; |
|
249 test "x#*y #+ z < (z #+ y) #+ (y#*x #+ w)"; |
|
250 |
|
251 test "x #+ succ(y) < x #+ z"; |
|
252 test "x #+ succ(y) < succ(z #+ x)"; |
|
253 test "succ(x) #+ succ(y) #+ z < succ(z #+ y) #+ succ(x #+ w)"; |
|
254 |
|
255 test "x #+ succ(y) le succ(z #+ x)"; |
|
256 *) |