(* Title: Provers/Arith/assoc_fold.ML
ID: $Id$
Author: Lawrence C Paulson, Cambridge University Computer Laboratory
Copyright 1999 University of Cambridge
Simplification procedure for associative operators + and * on numeric types
Performs constant folding when the literals are separated, as in 3+n+4.
*)
(*Parameters of the associative constant-folding simproc; an object-logic
  supplies a structure of this signature to functor Assoc_Fold*)
signature ASSOC_FOLD_DATA =
sig
  (*theory in which the operator is declared*)
  val thy_ref: theory_ref
  (*basic simpset of the object-logic*)
  val ss: simpset
  (*AC-rewrites for plus*)
  val add_ac: thm list
  (*turns an object-level equality into a meta-level equality*)
  val eq_reflection: thm
end;
(*Builds a simproc that gathers all numeral literals occurring in a nest of
  applications of one associative operator, so that they become adjacent.*)
functor Assoc_Fold (Data: ASSOC_FOLD_DATA) =
struct

(*Simpset used to justify each regrouping: the object-logic's basic simpset
  together with the AC-rewrites for plus*)
val assoc_ss = Data.ss addsimps Data.add_ac;

(*Signals "nothing to do"; proc converts it into NONE*)
exception Assoc_fail;

(*Combines a non-empty list of terms with the operator plus (via foldr1);
  an empty list raises Assoc_fail*)
fun mk_sum plus tms =
  if null tms then raise Assoc_fail
  else foldr1 (fn (a, b) => plus $ a $ b) tms;

(*Walks a nest of plus-applications, prepending each numeral literal to the
  first accumulator list and every other summand to the second.  The right
  operand is visited before the left one.*)
fun sift_terms _ (t as Const ("Numeral.number_of", _) $ _, (lits, others)) =
      (t :: lits, others)                       (*a literal*)
  | sift_terms plus (t as (f as Const _) $ x $ y, acc as (lits, others)) =
      if f = plus
      then sift_terms plus (x, sift_terms plus (y, acc))
      else (lits, t :: others)                  (*some other binary term*)
  | sift_terms _ (t, (lits, others)) =
      (lits, t :: others);                      (*any other summand*)

(*When true, proc reports its input and output via tracing*)
val trace = ref false;

(*The simproc itself.  Given the theory and a left-hand side of the form
  plus $ _ $ _, it separates literals from other summands and, if there are at
  least two literals, returns SOME th with th : lhs == (lits) + (others),
  proved by eq_reflection followed by simp_tac assoc_ss.  Returns NONE when
  Assoc_fail is raised (fewer than two literals, or no other summands).
  A left-hand side that is not a binary application causes an error.*)
fun proc thy _ lhs =
  let
    fun show tm = string_of_cterm (Thm.cterm_of thy tm)
    (*message is built lazily so show runs only while tracing*)
    fun trace_msg mk = if !trace then tracing (mk ()) else ()
    val _ = trace_msg (fn () => "assoc_fold simproc: LHS = " ^ show lhs)
    val plus =
      (case lhs of
        f $ _ $ _ => f
      | _ => error "Assoc_fold: bad pattern")
    val (lits, others) = sift_terms plus (lhs, ([], []))
    (*with fewer than two literals the number of terms cannot shrink*)
    val _ =
      (case lits of
        _ :: _ :: _ => ()
      | _ => raise Assoc_fail)
    val rhs = plus $ mk_sum plus lits $ mk_sum plus others
    val _ = trace_msg (fn () => "RHS = " ^ show rhs)
    val goal = Logic.mk_equals (lhs, rhs)
    fun tac _ = rtac Data.eq_reflection 1 THEN simp_tac assoc_ss 1
    val th = Tactic.prove thy [] [] goal tac
  in SOME th end
  handle Assoc_fail => NONE;

end;
(*test data:
set timing;
Goal "(#3 * (a * #34)) * (#2 * b * #9) = (x::int)";
Goal "a + b + c + d + e + f + g + h + i + j + k + l + m + n + oo + p + q + r + s + t + u + v + (w + x + y + z + a + #2 + b + #2 + c + #2 + d + #2 + e) + #2 + f + (#2 + g + #2 + h + #2 + i) + #2 + (j + #2 + k + #2 + l + #2 + m + #2) + n + #2 + (oo + #2 + p + #2 + q + #2 + r) + #2 + s + #2 + t + #2 + u + #2 + v + #2 + w + #2 + x + #2 + y + #2 + z + #2 = (uu::nat)";
*)