src/HOL/Probability/measurable.ML
changeset 59048 7dc8ac6f0895
parent 59047 8d7cec9b861d
child 59353 f0707dc3d9aa
--- a/src/HOL/Probability/measurable.ML	Mon Nov 24 12:20:35 2014 +0100
+++ b/src/HOL/Probability/measurable.ML	Mon Nov 24 12:20:14 2014 +0100
@@ -6,44 +6,66 @@
 
 signature MEASURABLE = 
 sig
+  type preprocessor = thm -> Proof.context -> (thm list * Proof.context)
+
   datatype level = Concrete | Generic
 
   val app_thm_attr : attribute context_parser
   val dest_thm_attr : attribute context_parser
+  val cong_thm_attr : attribute context_parser
   val measurable_thm_attr : bool * (bool * level) -> attribute
 
+  val add_del_cong_thm : bool -> thm -> Context.generic -> Context.generic ;
+
+  val get_all : Context.generic -> thm list
+  val get_dest : Context.generic -> thm list
+  val get_cong : Context.generic -> thm list
+
   val measurable_tac : Proof.context -> thm list -> tactic
 
   val simproc : Proof.context -> cterm -> thm option
 
-  val get_thms : Proof.context -> thm list
-  val get_all : Proof.context -> thm list
+  val add_preprocessor : string -> preprocessor -> Context.generic -> Context.generic
+  val del_preprocessor : string -> Context.generic -> Context.generic
+  val add_local_cong : thm -> Proof.context -> Proof.context
 end ;
 
 structure Measurable : MEASURABLE =
 struct
 
+type preprocessor = thm -> Proof.context -> (thm list * Proof.context)
+
 datatype level = Concrete | Generic;
 
 fun eq_measurable_thms ((th1, d1), (th2, d2)) = 
   d1 = d2 andalso Thm.eq_thm_prop (th1, th2) ;
 
+fun merge_dups (xs:(string * preprocessor) list) ys =
+  xs @ (filter (fn (name, _) => is_none (find_first (fn (name', _) => name' = name) xs)) ys) 
+
 structure Data = Generic_Data
 (
   type T = {
     measurable_thms : (thm * (bool * level)) Item_Net.T,
     dest_thms : thm Item_Net.T,
-    app_thms : thm Item_Net.T }
+    app_thms : thm Item_Net.T,
+    cong_thms : thm Item_Net.T,
+    preprocessors : (string * preprocessor) list }
   val empty = {
     measurable_thms = Item_Net.init eq_measurable_thms (single o Thm.prop_of o fst),
     dest_thms = Thm.full_rules,
-    app_thms = Thm.full_rules };
+    app_thms = Thm.full_rules,
+    cong_thms = Thm.full_rules,
+    preprocessors = [] };
   val extend = I;
-  fun merge ({measurable_thms = t1, dest_thms = dt1, app_thms = at1 },
-      {measurable_thms = t2, dest_thms = dt2, app_thms = at2 }) = {
+  fun merge ({measurable_thms = t1, dest_thms = dt1, app_thms = at1, cong_thms = ct1, preprocessors = i1 },
+      {measurable_thms = t2, dest_thms = dt2, app_thms = at2, cong_thms = ct2, preprocessors = i2 }) = {
     measurable_thms = Item_Net.merge (t1, t2),
     dest_thms = Item_Net.merge (dt1, dt2),
-    app_thms = Item_Net.merge (at1, at2) };
+    app_thms = Item_Net.merge (at1, at2),
+    cong_thms = Item_Net.merge (ct1, ct2),
+    preprocessors = merge_dups i1 i2 
+    };
 );
 
 val debug =
@@ -52,13 +74,15 @@
 val split =
   Attrib.setup_config_bool @{binding measurable_split} (K true)
 
-fun map_data f1 f2 f3
-  {measurable_thms = t1,    dest_thms = t2,    app_thms = t3} =
-  {measurable_thms = f1 t1, dest_thms = f2 t2, app_thms = f3 t3 }
+fun map_data f1 f2 f3 f4 f5
+  {measurable_thms = t1,    dest_thms = t2,    app_thms = t3,    cong_thms = t4,    preprocessors = t5 } =
+  {measurable_thms = f1 t1, dest_thms = f2 t2, app_thms = f3 t3, cong_thms = f4 t4, preprocessors = f5 t5}
 
-fun map_measurable_thms f = map_data f I I
-fun map_dest_thms f = map_data I f I
-fun map_app_thms f = map_data I I f
+fun map_measurable_thms f = map_data f I I I I
+fun map_dest_thms f = map_data I f I I I
+fun map_app_thms f = map_data I I f I I
+fun map_cong_thms f = map_data I I I f I
+fun map_preprocessors f = map_data I I I I f
 
 fun generic_add_del map = 
   Scan.lift
@@ -69,6 +93,8 @@
 
 val dest_thm_attr = generic_add_del map_dest_thms
 
+val cong_thm_attr = generic_add_del map_cong_thms
+
 fun del_thm th net =
   let
     val thms = net |> Item_Net.content |> filter (fn (th', _) => Thm.eq_thm (th, th'))
@@ -80,29 +106,30 @@
 val get_dest = Item_Net.content o #dest_thms o Data.get;
 val get_app = Item_Net.content o #app_thms o Data.get;
 
+val get_cong = Item_Net.content o #cong_thms o Data.get;
+val add_cong = Data.map o map_cong_thms o Item_Net.update;
+val del_cong = Data.map o map_cong_thms o Item_Net.remove;
+fun add_del_cong_thm true = add_cong
+  | add_del_cong_thm false = del_cong
+
+fun add_preprocessor name f = Data.map (map_preprocessors (fn xs => xs @ [(name, f)]))
+fun del_preprocessor name = Data.map (map_preprocessors (filter (fn (n, _) => n <> name)))
+val add_local_cong = Context.proof_map o add_cong
+
+val get_preprocessors = Context.Proof #> Data.get #> #preprocessors ;
+
 fun is_too_generic thm =
   let 
     val concl = concl_of thm
     val concl' = HOLogic.dest_Trueprop concl handle TERM _ => concl
   in is_Var (head_of concl') end
 
-fun import_theorem ctxt thm = if is_too_generic thm then [] else
-  [thm] @ map_filter (try (fn th' => thm RS th')) (get_dest ctxt);
-
-val get = Context.Proof #> Data.get #> #measurable_thms #> Item_Net.content ;
-
-val get_all = get #> map fst ;
+val get_thms = Data.get #> #measurable_thms #> Item_Net.content ;
 
-fun get_thms ctxt =
-  let
-    val thms = ctxt |> get |> rev ;
-    fun get lv = map_filter (fn (th, (rw, lv')) => if lv = lv' then SOME (th, rw) else NONE) thms
-  in
-    get Concrete @ get Generic |>
-    maps (fn (th, rw) => if rw then [th] else import_theorem (Context.Proof ctxt) th)
-  end;
+val get_all = get_thms #> map fst ;
 
-fun debug_tac ctxt msg f = if Config.get ctxt debug then print_tac ctxt (msg ()) THEN f else f
+fun debug_tac ctxt msg f =
+  if Config.get ctxt debug then print_tac ctxt (msg ()) THEN f else f
 
 fun nth_hol_goal thm i =
   HOLogic.dest_Trueprop (Logic.strip_imp_concl (strip_all_body (nth (prems_of thm) (i - 1))))
@@ -112,7 +139,7 @@
     (Const (@{const_name "Set.member"}, _) $ f $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => f
   | _ => raise (TERM ("not a measurability predicate", [t])))
 
-fun is_cond_formula n thm = if length (prems_of thm) < n then false else
+fun not_measurable_prop n thm = if length (prems_of thm) < n then false else
   (case nth_hol_goal thm n of
     (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "sets"}, _) $ _)) => false
   | (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => false
@@ -148,73 +175,118 @@
     end
   | cnt_prefixes _ _ = []
 
-val split_countable_tac =
-  Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
-    let
-      val f = dest_measurable_fun (HOLogic.dest_Trueprop t)
-      fun cert f = map (Option.map (f (Proof_Context.theory_of ctxt)))
-      fun inst t (ts, Ts) = Drule.instantiate' (cert ctyp_of Ts) (cert cterm_of ts) t
-      val cps = cnt_prefixes ctxt f |> map (inst @{thm measurable_compose_countable})
-    in if null cps then no_tac else debug_tac ctxt (K "split countable fun") (resolve_tac cps i) end
-    handle TERM _ => no_tac) 1)
-
-val split_app_tac =
-  Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
-    let
-      fun app_prefixes (Abs (n, T, (f $ g))) = let
-            val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
-          in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
-        | app_prefixes _ = []
-
-      fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
-        | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
-      val thy = Proof_Context.theory_of ctxt
-      val tunify = Sign.typ_unify thy
-      val thms = map
-          (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
-          (get_app (Context.Proof ctxt))
-      fun cert f = map (fn (t, t') => (f thy t, f thy t'))
-      fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
-        let
-          val inst =
-            (Vartab.empty, ~1)
-            |> tunify (T, thmT)
-            |> tunify (Tf, thmTf)
-            |> tunify (Tc, thmTc)
-            |> Vartab.dest o fst
-          val subst = subst_TVars (map (apsnd snd) inst)
-        in
-          Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
-            cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
-        end
-      val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
-    in if null cps then no_tac
-        else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
-          ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
-    handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
-    handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
-
 fun measurable_tac ctxt facts =
   let
-    val imported_thms =
-      (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf ctxt) facts) @ get_thms ctxt
+    fun debug_fact msg thm () =
+      msg ^ " " ^ Pretty.str_of (Syntax.pretty_term ctxt (prop_of thm))
+
+    fun IF' c t i = COND (c i) (t i) no_tac
+
+    fun r_tac msg =
+      if Config.get ctxt debug
+      then FIRST' o
+        map (fn thm => resolve_tac [thm]
+          THEN' K (debug_tac ctxt (debug_fact (msg ^ "resolved using") thm) all_tac))
+      else resolve_tac
+
+    val elem_congI = @{lemma "A = B \<Longrightarrow> x \<in> B \<Longrightarrow> x \<in> A" by simp}
+
+    val dests = get_dest (Context.Proof ctxt)
+    fun prep_dest thm =
+      (if is_too_generic thm then [] else [thm] @ map_filter (try (fn th' => thm RS th')) dests) ;
+    val preprocessors = (("std", prep_dest #> pair) :: get_preprocessors ctxt) ;
+    fun preprocess_thm (thm, raw) =
+      if raw then pair [thm] else fold_map (fn (_, proc) => proc thm) preprocessors #>> flat
+    
+    fun sel lv (th, (raw, lv')) = if lv = lv' then SOME (th, raw) else NONE ;
+    fun get lv = ctxt |> Context.Proof |> get_thms |> rev |> map_filter (sel lv) ;
+    val pre_thms = map (Simplifier.norm_hhf ctxt #> rpair false) facts @ get Concrete @ get Generic
+
+    val (thms, ctxt) = fold_map preprocess_thm pre_thms ctxt |>> flat
 
-    fun debug_facts msg () =
-      msg ^ " + " ^ Pretty.str_of (Pretty.list "[" "]"
-        (map (Syntax.pretty_term ctxt o prop_of) (maps (import_theorem (Context.Proof ctxt)) facts)));
+    fun is_sets_eq (Const (@{const_name "HOL.eq"}, _) $
+          (Const (@{const_name "sets"}, _) $ _) $
+          (Const (@{const_name "sets"}, _) $ _)) = true
+      | is_sets_eq (Const (@{const_name "HOL.eq"}, _) $
+          (Const (@{const_name "measurable"}, _) $ _ $ _) $
+          (Const (@{const_name "measurable"}, _) $ _ $ _)) = true
+      | is_sets_eq _ = false
+
+    val cong_thms = get_cong (Context.Proof ctxt) @
+      filter (fn thm => concl_of thm |> HOLogic.dest_Trueprop |> is_sets_eq handle TERM _ => false) facts
+
+    fun sets_cong_tac i =
+      Subgoal.FOCUS (fn {context = ctxt', prems = prems, ...} => (
+        let
+          val ctxt'' = Simplifier.add_prems prems ctxt'
+        in
+          r_tac "cong intro" [elem_congI]
+          THEN' SOLVED' (fn i => REPEAT_DETERM (
+              ((r_tac "cong solve" (cong_thms @ [@{thm refl}])
+                ORELSE' IF' (fn i => fn thm => nprems_of thm > i)
+                  (SOLVED' (asm_full_simp_tac ctxt''))) i)))
+        end) 1) ctxt i
+        THEN flexflex_tac ctxt
+
+    val simp_solver_tac = 
+      IF' not_measurable_prop (debug_tac ctxt (K "simp ") o SOLVED' (asm_full_simp_tac ctxt))
+
+    val split_countable_tac =
+      Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
+        let
+          val f = dest_measurable_fun (HOLogic.dest_Trueprop t)
+          fun cert f = map (Option.map (f (Proof_Context.theory_of ctxt)))
+          fun inst (ts, Ts) =
+            Drule.instantiate' (cert ctyp_of Ts) (cert cterm_of ts) @{thm measurable_compose_countable}
+        in r_tac "split countable" (cnt_prefixes ctxt f |> map inst) i end
+        handle TERM _ => no_tac) 1)
 
     val splitter = if Config.get ctxt split then split_countable_tac ctxt else K no_tac
 
-    fun REPEAT_cnt f n st = ((f n THEN REPEAT_cnt f (n + 1)) ORELSE all_tac) st
+    val split_app_tac =
+      Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
+        let
+          fun app_prefixes (Abs (n, T, (f $ g))) = let
+                val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
+              in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
+            | app_prefixes _ = []
+    
+          fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
+            | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
+          val thy = Proof_Context.theory_of ctxt
+          val tunify = Sign.typ_unify thy
+          val thms = map
+              (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
+              (get_app (Context.Proof ctxt))
+          fun cert f = map (fn (t, t') => (f thy t, f thy t'))
+          fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
+            let
+              val inst =
+                (Vartab.empty, ~1)
+                |> tunify (T, thmT)
+                |> tunify (Tf, thmTf)
+                |> tunify (Tc, thmTc)
+                |> Vartab.dest o fst
+              val subst = subst_TVars (map (apsnd snd) inst)
+            in
+              Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
+                cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
+            end
+          val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
+        in if null cps then no_tac
+            else r_tac "split app" cps i ORELSE debug_tac ctxt (fn () => "split app fun FAILED") no_tac end
+        handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
+        handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
 
-    val depth_measurable_tac = REPEAT_cnt (fn n =>
-       (COND (is_cond_formula 1)
-        (debug_tac ctxt (K ("simp " ^ string_of_int n)) (SOLVED' (asm_full_simp_tac ctxt) 1))
-        ((debug_tac ctxt (K ("single " ^ string_of_int n)) (resolve_tac imported_thms 1)) APPEND
-          (split_app_tac ctxt 1) APPEND
-          (splitter 1)))) 0
+    val single_step_tac =
+      simp_solver_tac
+      ORELSE' r_tac "step" thms
+      ORELSE' (split_app_tac ctxt)
+      ORELSE' splitter
+      ORELSE' (CHANGED o sets_cong_tac)
+      ORELSE' (K (debug_tac ctxt (K "backtrack") no_tac))
 
-  in debug_tac ctxt (debug_facts "start") depth_measurable_tac end;
+  in debug_tac ctxt (K "start") (REPEAT (single_step_tac 1)) end;
 
 fun simproc ctxt redex =
   let