src/HOL/Probability/measurable.ML
changeset 59047 8d7cec9b861d
parent 58965 a62cdcc5344b
child 59048 7dc8ac6f0895
--- a/src/HOL/Probability/measurable.ML	Mon Nov 24 12:35:13 2014 +0100
+++ b/src/HOL/Probability/measurable.ML	Mon Nov 24 12:20:35 2014 +0100
@@ -8,20 +8,16 @@
 sig
   datatype level = Concrete | Generic
 
-  val add_app : thm -> Context.generic -> Context.generic
-  val add_dest : thm -> Context.generic -> Context.generic
-  val add_thm : bool * level -> thm -> Context.generic -> Context.generic
-  val del_thm : bool * level -> thm -> Context.generic -> Context.generic
-  val add_del_thm : bool -> (bool * level) -> thm -> Context.generic -> Context.generic
+  val app_thm_attr : attribute context_parser
+  val dest_thm_attr : attribute context_parser
+  val measurable_thm_attr : bool * (bool * level) -> attribute
 
   val measurable_tac : Proof.context -> thm list -> tactic
 
   val simproc : Proof.context -> cterm -> thm option
 
-  val get : level -> Proof.context -> thm list
+  val get_thms : Proof.context -> thm list
   val get_all : Proof.context -> thm list
-
-  val update : (thm Item_Net.T -> thm Item_Net.T) -> level -> Context.generic -> Context.generic
 end ;
 
 structure Measurable : MEASURABLE =
@@ -29,23 +25,23 @@
 
 datatype level = Concrete | Generic;
 
+fun eq_measurable_thms ((th1, d1), (th2, d2)) = 
+  d1 = d2 andalso Thm.eq_thm_prop (th1, th2) ;
+
 structure Data = Generic_Data
 (
   type T = {
-    concrete_thms : thm Item_Net.T,
-    generic_thms : thm Item_Net.T,
+    measurable_thms : (thm * (bool * level)) Item_Net.T,
     dest_thms : thm Item_Net.T,
     app_thms : thm Item_Net.T }
   val empty = {
-    concrete_thms = Thm.full_rules,
-    generic_thms = Thm.full_rules,
+    measurable_thms = Item_Net.init eq_measurable_thms (single o Thm.prop_of o fst),
     dest_thms = Thm.full_rules,
-    app_thms = Thm.full_rules};
+    app_thms = Thm.full_rules };
   val extend = I;
-  fun merge ({concrete_thms = ct1, generic_thms = gt1, dest_thms = dt1, app_thms = at1 },
-      {concrete_thms = ct2, generic_thms = gt2, dest_thms = dt2, app_thms = at2 }) = {
-    concrete_thms = Item_Net.merge (ct1, ct2),
-    generic_thms = Item_Net.merge (gt1, gt2),
+  fun merge ({measurable_thms = t1, dest_thms = dt1, app_thms = at1 },
+      {measurable_thms = t2, dest_thms = dt2, app_thms = at2 }) = {
+    measurable_thms = Item_Net.merge (t1, t2),
     dest_thms = Item_Net.merge (dt1, dt2),
     app_thms = Item_Net.merge (at1, at2) };
 );
@@ -53,38 +49,36 @@
 val debug =
   Attrib.setup_config_bool @{binding measurable_debug} (K false)
 
-val backtrack =
-  Attrib.setup_config_int @{binding measurable_backtrack} (K 20)
-
 val split =
   Attrib.setup_config_bool @{binding measurable_split} (K true)
 
-fun TAKE n tac = Seq.take n o tac
+fun map_data f1 f2 f3
+  {measurable_thms = t1,    dest_thms = t2,    app_thms = t3} =
+  {measurable_thms = f1 t1, dest_thms = f2 t2, app_thms = f3 t3 }
 
-fun get lv =
-  rev o Item_Net.content o (case lv of Concrete => #concrete_thms | Generic => #generic_thms) o
-  Data.get o Context.Proof;
-
-fun get_all ctxt = get Concrete ctxt @ get Generic ctxt;
+fun map_measurable_thms f = map_data f I I
+fun map_dest_thms f = map_data I f I
+fun map_app_thms f = map_data I I f
 
-fun map_data f1 f2 f3 f4
-  {generic_thms = t1,    concrete_thms = t2,    dest_thms = t3,    app_thms = t4} =
-  {generic_thms = f1 t1, concrete_thms = f2 t2, dest_thms = f3 t3, app_thms = f4 t4 }
+fun generic_add_del map = 
+  Scan.lift
+    (Args.add >> K Item_Net.update || Args.del >> K Item_Net.remove || Scan.succeed Item_Net.update) >>
+    (fn f => Thm.declaration_attribute (Data.map o map o f))
+
+val app_thm_attr = generic_add_del map_app_thms
 
-fun map_concrete_thms f = map_data f I I I
-fun map_generic_thms f = map_data I f I I
-fun map_dest_thms f = map_data I I f I
-fun map_app_thms f = map_data I I I f
+val dest_thm_attr = generic_add_del map_dest_thms
 
-fun update f lv = Data.map (case lv of Concrete => map_concrete_thms f | Generic => map_generic_thms f);
-fun add thms' = update (fold Item_Net.update thms');
-fun del thms' = update (fold Item_Net.remove thms');
+fun del_thm th net =
+  let
+    val thms = net |> Item_Net.content |> filter (fn (th', _) => Thm.eq_thm (th, th'))
+  in fold Item_Net.remove thms net end ;
+
+fun measurable_thm_attr (do_add, d) = Thm.declaration_attribute
+  (Data.map o map_measurable_thms o (if do_add then Item_Net.update o rpair d else del_thm))
 
 val get_dest = Item_Net.content o #dest_thms o Data.get;
-val add_dest = Data.map o map_dest_thms o Item_Net.update;
-
 val get_app = Item_Net.content o #app_thms o Data.get;
-val add_app = Data.map o map_app_thms o Item_Net.update;
 
 fun is_too_generic thm =
   let 
@@ -95,12 +89,18 @@
 fun import_theorem ctxt thm = if is_too_generic thm then [] else
   [thm] @ map_filter (try (fn th' => thm RS th')) (get_dest ctxt);
 
-fun add_del_thm_gen f (raw, lv) thm ctxt = f (if raw then [thm] else import_theorem ctxt thm) lv ctxt;
+val get = Context.Proof #> Data.get #> #measurable_thms #> Item_Net.content ;
+
+val get_all = get #> map fst ;
 
-val add_thm = add_del_thm_gen add;
-val del_thm = add_del_thm_gen del;
-fun add_del_thm true = add_thm
-  | add_del_thm false = del_thm
+fun get_thms ctxt =
+  let
+    val thms = ctxt |> get |> rev ;
+    fun get lv = map_filter (fn (th, (rw, lv')) => if lv = lv' then SOME (th, rw) else NONE) thms
+  in
+    get Concrete @ get Generic |>
+    maps (fn (th, rw) => if rw then [th] else import_theorem (Context.Proof ctxt) th)
+  end;
 
 fun debug_tac ctxt msg f = if Config.get ctxt debug then print_tac ctxt (msg ()) THEN f else f
 
@@ -158,10 +158,46 @@
     in if null cps then no_tac else debug_tac ctxt (K "split countable fun") (resolve_tac cps i) end
     handle TERM _ => no_tac) 1)
 
-fun measurable_tac' ctxt facts =
+val split_app_tac =
+  Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
+    let
+      fun app_prefixes (Abs (n, T, (f $ g))) = let
+            val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
+          in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
+        | app_prefixes _ = []
+
+      fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
+        | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
+      val thy = Proof_Context.theory_of ctxt
+      val tunify = Sign.typ_unify thy
+      val thms = map
+          (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
+          (get_app (Context.Proof ctxt))
+      fun cert f = map (fn (t, t') => (f thy t, f thy t'))
+      fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
+        let
+          val inst =
+            (Vartab.empty, ~1)
+            |> tunify (T, thmT)
+            |> tunify (Tf, thmTf)
+            |> tunify (Tc, thmTc)
+            |> Vartab.dest o fst
+          val subst = subst_TVars (map (apsnd snd) inst)
+        in
+          Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
+            cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
+        end
+      val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
+    in if null cps then no_tac
+        else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
+          ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
+    handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
+    handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
+
+fun measurable_tac ctxt facts =
   let
     val imported_thms =
-      (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf ctxt) facts) @ get_all ctxt
+      (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf ctxt) facts) @ get_thms ctxt
 
     fun debug_facts msg () =
       msg ^ " + " ^ Pretty.str_of (Pretty.list "[" "]"
@@ -169,42 +205,6 @@
 
     val splitter = if Config.get ctxt split then split_countable_tac ctxt else K no_tac
 
-    val split_app_tac =
-      Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
-        let
-          fun app_prefixes (Abs (n, T, (f $ g))) = let
-                val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
-              in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
-            | app_prefixes _ = []
-
-          fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
-            | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
-          val thy = Proof_Context.theory_of ctxt
-          val tunify = Sign.typ_unify thy
-          val thms = map
-              (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
-              (get_app (Context.Proof ctxt))
-          fun cert f = map (fn (t, t') => (f thy t, f thy t'))
-          fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
-            let
-              val inst =
-                (Vartab.empty, ~1)
-                |> tunify (T, thmT)
-                |> tunify (Tf, thmTf)
-                |> tunify (Tc, thmTc)
-                |> Vartab.dest o fst
-              val subst = subst_TVars (map (apsnd snd) inst)
-            in
-              Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
-                cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
-            end
-          val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
-        in if null cps then no_tac
-            else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
-              ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
-        handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
-        handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
-
     fun REPEAT_cnt f n st = ((f n THEN REPEAT_cnt f (n + 1)) ORELSE all_tac) st
 
     val depth_measurable_tac = REPEAT_cnt (fn n =>
@@ -216,14 +216,11 @@
 
   in debug_tac ctxt (debug_facts "start") depth_measurable_tac end;
 
-fun measurable_tac ctxt facts =
-  TAKE (Config.get ctxt backtrack) (measurable_tac' ctxt facts);
-
 fun simproc ctxt redex =
   let
     val t = HOLogic.mk_Trueprop (term_of redex);
     fun tac {context = ctxt, prems = _ } =
-      SOLVE (measurable_tac' ctxt (Simplifier.prems_of ctxt));
+      SOLVE (measurable_tac ctxt (Simplifier.prems_of ctxt));
   in try (fn () => Goal.prove ctxt [] [] t tac RS @{thm Eq_TrueI}) () end;
 
 end