src/HOL/Tools/lin_arith.ML
changeset 31100 6a2e67fe4488
parent 31082 54a442b2d727
child 31101 26c7bb764a38
--- a/src/HOL/Tools/lin_arith.ML	Mon May 11 11:53:21 2009 +0200
+++ b/src/HOL/Tools/lin_arith.ML	Mon May 11 15:18:32 2009 +0200
@@ -6,27 +6,25 @@
 
 signature BASIC_LIN_ARITH =
 sig
-  val arith_split_add: attribute
-  val lin_arith_pre_tac: Proof.context -> int -> tactic
+  val lin_arith_simproc: simpset -> term -> thm option
+  val fast_nat_arith_simproc: simproc
   val fast_arith_tac: Proof.context -> int -> tactic
   val fast_ex_arith_tac: Proof.context -> bool -> int -> tactic
-  val lin_arith_simproc: simpset -> term -> thm option
-  val fast_nat_arith_simproc: simproc
   val linear_arith_tac: Proof.context -> int -> tactic
 end;
 
 signature LIN_ARITH =
 sig
   include BASIC_LIN_ARITH
-  val add_discrete_type: string -> Context.generic -> Context.generic
+  val pre_tac: Proof.context -> int -> tactic
+  val add_inj_thms: thm list -> Context.generic -> Context.generic
+  val add_lessD: thm -> Context.generic -> Context.generic
+  val add_simps: thm list -> Context.generic -> Context.generic
+  val add_simprocs: simproc list -> Context.generic -> Context.generic
   val add_inj_const: string * typ -> Context.generic -> Context.generic
-  val map_data:
-    ({add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
-      lessD: thm list, neqE: thm list, simpset: Simplifier.simpset} ->
-     {add_mono_thms: thm list, mult_mono_thms: thm list, inj_thms: thm list,
-      lessD: thm list, neqE: thm list, simpset: Simplifier.simpset}) ->
-    Context.generic -> Context.generic
+  val add_discrete_type: string -> Context.generic -> Context.generic
   val setup: Context.generic -> Context.generic
+  val global_setup: theory -> theory
   val split_limit: int Config.T
   val neq_limit: int Config.T
   val warning_count: int ref
@@ -47,37 +45,38 @@
 val sym = sym;
 val not_lessD = @{thm linorder_not_less} RS iffD1;
 val not_leD = @{thm linorder_not_le} RS iffD1;
-val le0 = thm "le0";
 
-fun mk_Eq thm = (thm RS Eq_FalseI) handle THM _ => (thm RS Eq_TrueI);
+fun mk_Eq thm = thm RS Eq_FalseI handle THM _ => thm RS Eq_TrueI;
 
 val mk_Trueprop = HOLogic.mk_Trueprop;
 
 fun atomize thm = case Thm.prop_of thm of
-    Const("Trueprop",_) $ (Const("op &",_) $ _ $ _) =>
-    atomize(thm RS conjunct1) @ atomize(thm RS conjunct2)
+    Const ("Trueprop", _) $ (Const (@{const_name "op &"}, _) $ _ $ _) =>
+    atomize (thm RS conjunct1) @ atomize (thm RS conjunct2)
   | _ => [thm];
 
-fun neg_prop ((TP as Const("Trueprop",_)) $ (Const("Not",_) $ t)) = TP $ t
-  | neg_prop ((TP as Const("Trueprop",_)) $ t) = TP $ (HOLogic.Not $t)
+fun neg_prop ((TP as Const("Trueprop", _)) $ (Const (@{const_name "Not"}, _) $ t)) = TP $ t
+  | neg_prop ((TP as Const("Trueprop", _)) $ t) = TP $ (HOLogic.Not $t)
   | neg_prop t = raise TERM ("neg_prop", [t]);
 
 fun is_False thm =
   let val _ $ t = Thm.prop_of thm
-  in t = Const("False",HOLogic.boolT) end;
+  in t = HOLogic.false_const end;
 
 fun is_nat t = (fastype_of1 t = HOLogic.natT);
 
-fun mk_nat_thm sg t =
-  let val ct = cterm_of sg t  and cn = cterm_of sg (Var(("n",0),HOLogic.natT))
-  in instantiate ([],[(cn,ct)]) le0 end;
+fun mk_nat_thm thy t =
+  let
+    val cn = cterm_of thy (Var (("n", 0), HOLogic.natT))
+    and ct = cterm_of thy t
+  in instantiate ([], [(cn, ct)]) @{thm le0} end;
 
 end;
 
 
 (* arith context data *)
 
-structure ArithContextData = GenericDataFun
+structure Lin_Arith_Data = GenericDataFun
 (
   type T = {splits: thm list,
             inj_consts: (string * typ) list,
@@ -92,27 +91,25 @@
     discrete = Library.merge (op =) (discrete1, discrete2)};
 );
 
-val get_arith_data = ArithContextData.get o Context.Proof;
+val get_arith_data = Lin_Arith_Data.get o Context.Proof;
 
-val arith_split_add = Thm.declaration_attribute (fn thm =>
-  ArithContextData.map (fn {splits, inj_consts, discrete} =>
-    {splits = update Thm.eq_thm_prop thm splits,
-     inj_consts = inj_consts, discrete = discrete}));
+fun add_split thm = Lin_Arith_Data.map (fn {splits, inj_consts, discrete} =>
+  {splits = update Thm.eq_thm_prop thm splits,
+   inj_consts = inj_consts, discrete = discrete});
 
-fun add_discrete_type d = ArithContextData.map (fn {splits, inj_consts, discrete} =>
+fun add_discrete_type d = Lin_Arith_Data.map (fn {splits, inj_consts, discrete} =>
   {splits = splits, inj_consts = inj_consts,
    discrete = update (op =) d discrete});
 
-fun add_inj_const c = ArithContextData.map (fn {splits, inj_consts, discrete} =>
+fun add_inj_const c = Lin_Arith_Data.map (fn {splits, inj_consts, discrete} =>
   {splits = splits, inj_consts = update (op =) c inj_consts,
    discrete = discrete});
 
-val (split_limit, setup1) = Attrib.config_int "linarith_split_limit" 9;
-val (neq_limit, setup2) = Attrib.config_int "linarith_neq_limit" 9;
-val setup_options = setup1 #> setup2;
+val (split_limit, setup_split_limit) = Attrib.config_int "linarith_split_limit" 9;
+val (neq_limit, setup_neq_limit) = Attrib.config_int "linarith_neq_limit" 9;
 
 
-structure LA_Data_Ref =
+structure LA_Data =
 struct
 
 val fast_arith_neq_limit = neq_limit;
@@ -756,15 +753,32 @@
   )
 end;
 
-end;  (* LA_Data_Ref *)
+end;  (* LA_Data *)
 
 
-val lin_arith_pre_tac = LA_Data_Ref.pre_tac;
+val pre_tac = LA_Data.pre_tac;
 
-structure Fast_Arith = Fast_Lin_Arith(structure LA_Logic = LA_Logic and LA_Data = LA_Data_Ref);
+structure Fast_Arith = Fast_Lin_Arith(structure LA_Logic = LA_Logic and LA_Data = LA_Data);
 
 val map_data = Fast_Arith.map_data;
 
+fun map_inj_thms f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} =
+  {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = f inj_thms,
+    lessD = lessD, neqE = neqE, simpset = simpset};
+
+fun map_lessD f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} =
+  {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
+    lessD = f lessD, neqE = neqE, simpset = simpset};
+
+fun map_simpset f {add_mono_thms, mult_mono_thms, inj_thms, lessD, neqE, simpset} =
+  {add_mono_thms = add_mono_thms, mult_mono_thms = mult_mono_thms, inj_thms = inj_thms,
+    lessD = lessD, neqE = neqE, simpset = f simpset};
+
+fun add_inj_thms thms = Fast_Arith.map_data (map_inj_thms (append thms));
+fun add_lessD thm = Fast_Arith.map_data (map_lessD (fn thms => thms @ [thm]));
+fun add_simps thms = Fast_Arith.map_data (map_simpset (fn simpset => simpset addsimps thms));
+fun add_simprocs procs = Fast_Arith.map_data (map_simpset (fn simpset => simpset addsimprocs procs));
+
 fun fast_arith_tac ctxt = Fast_Arith.lin_arith_tac ctxt false;
 val fast_ex_arith_tac = Fast_Arith.lin_arith_tac;
 val trace = Fast_Arith.trace;
@@ -774,7 +788,7 @@
    Most of the work is done by the cancel tactics. *)
 
 val init_arith_data =
- map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, ...} =>
+  Fast_Arith.map_data (fn {add_mono_thms, mult_mono_thms, inj_thms, lessD, ...} =>
    {add_mono_thms = @{thms add_mono_thms_ordered_semiring} @ @{thms add_mono_thms_ordered_field} @ add_mono_thms,
     mult_mono_thms = @{thm mult_strict_left_mono} :: @{thm mult_left_mono} :: mult_mono_thms,
     inj_thms = inj_thms,
@@ -891,17 +905,18 @@
   init_arith_data #>
   Simplifier.map_ss (fn ss => ss addsimprocs [fast_nat_arith_simproc]
     addSolver (mk_solver' "lin_arith"
-      (add_arith_facts #> Fast_Arith.cut_lin_arith_tac))) #>
-  Context.mapping
-   (setup_options #>
-    Arith_Data.add_tactic "linear arithmetic" gen_linear_arith_tac #>
-    Method.setup @{binding linarith}
-      (Args.bang_facts >> (fn prems => fn ctxt =>
-        METHOD (fn facts =>
-          HEADGOAL (Method.insert_tac (prems @ Arith_Data.get_arith_facts ctxt @ facts)
-            THEN' linear_arith_tac ctxt)))) "linear arithmetic" #>
-    Attrib.setup @{binding arith_split} (Scan.succeed arith_split_add)
-      "declaration of split rules for arithmetic procedure") I;
+      (add_arith_facts #> Fast_Arith.cut_lin_arith_tac)))
+
+val global_setup =
+  setup_split_limit #> setup_neq_limit #>
+  Attrib.setup @{binding arith_split} (Scan.succeed (Thm.declaration_attribute add_split))
+    "declaration of split rules for arithmetic procedure" #>
+  Method.setup @{binding linarith}
+    (Args.bang_facts >> (fn prems => fn ctxt =>
+      METHOD (fn facts =>
+        HEADGOAL (Method.insert_tac (prems @ Arith_Data.get_arith_facts ctxt @ facts)
+          THEN' linear_arith_tac ctxt)))) "linear arithmetic" #>
+  Arith_Data.add_tactic "linear arithmetic" gen_linear_arith_tac;
 
 end;