diff -r 42e0a0bfef73 -r d2b1fc1b8e19 src/HOL/Tools/SMT/smt_normalize.ML --- a/src/HOL/Tools/SMT/smt_normalize.ML Mon Dec 06 16:54:22 2010 +0100 +++ b/src/HOL/Tools/SMT/smt_normalize.ML Tue Dec 07 14:53:12 2010 +0100 @@ -23,12 +23,14 @@ (int * thm) list * Proof.context val atomize_conv: Proof.context -> conv val eta_expand_conv: (Proof.context -> conv) -> Proof.context -> conv + val setup: theory -> theory end structure SMT_Normalize: SMT_NORMALIZE = struct structure U = SMT_Utils +structure B = SMT_Builtin infix 2 ?? fun (test ?? f) x = if test x then f x else x @@ -95,6 +97,9 @@ fun rewrite_bool_cases ctxt = map (apsnd ((Term.exists_subterm is_bool_case o Thm.prop_of) ?? Conv.fconv_rule (Conv.top_conv (K unfold_conv) ctxt))) + +val setup_bool_case = B.add_builtin_fun_ext'' @{const_name "bool.bool_case"} + end @@ -203,10 +208,20 @@ val uses_nat_type = Term.exists_type (Term.exists_subtype (equal @{typ nat})) val uses_nat_int = Term.exists_subterm (member (op aconv) [@{const of_nat (int)}, @{const nat}]) + + val nat_ops = [ + @{const less (nat)}, @{const less_eq (nat)}, + @{const Suc}, @{const plus (nat)}, @{const minus (nat)}, + @{const times (nat)}, @{const div (nat)}, @{const mod (nat)}] + val nat_ops' = @{const of_nat (int)} :: @{const nat} :: nat_ops in fun nat_as_int ctxt = map (apsnd ((uses_nat_type o Thm.prop_of) ?? Conv.fconv_rule (conv ctxt))) #> exists (uses_nat_int o Thm.prop_of o snd) ?? append nat_embedding + +val setup_nat_as_int = + B.add_builtin_typ_ext (@{typ nat}, K true) #> + fold (B.add_builtin_fun_ext' o Term.dest_Const) nat_ops' end @@ -263,7 +278,7 @@ | _ => (case Term.strip_comb (Thm.term_of ct) of (Const (c as (_, T)), ts) => - if SMT_Builtin.is_partially_builtin ctxt c + if SMT_Builtin.is_builtin_fun ctxt c ts then eta_args_conv norm_conv (length (Term.binder_types T) - length ts) else args_conv o norm_conv @@ -294,7 +309,7 @@ | _ => (case Term.strip_comb t of (Const (c as (_, T)), ts) => - if SMT_Builtin.is_builtin ctxt c + if SMT_Builtin.is_builtin_fun ctxt c ts then length (Term.binder_types T) = length ts andalso forall (is_normed ctxt) ts else forall (is_normed ctxt) ts @@ -302,6 +317,11 @@ in fun norm_binder_conv ctxt = U.if_conv (is_normed ctxt) Conv.all_conv (norm_conv ctxt) + +val setup_unfolded_quants = + fold B.add_builtin_fun_ext'' [@{const_name Ball}, @{const_name Bex}, + @{const_name Ex1}] + end fun norm_def ctxt thm = @@ -325,6 +345,10 @@ Conv.rewr_conv @{thm atomize_all} | _ => Conv.all_conv) ct +val setup_atomize = + fold B.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="}, + @{const_name all}, @{const_name Trueprop}] + fun normalize_rule ctxt = Conv.fconv_rule ( (* reduce lambda abstractions, except at known binders: *) @@ -554,4 +578,14 @@ |-> (if with_datatypes then datatype_selectors else pair) end + + +(* setup *) + +val setup = + setup_bool_case #> + setup_nat_as_int #> + setup_unfolded_quants #> + setup_atomize + end