add support for function application to measurability prover
authorhoelzl
Tue, 06 Nov 2012 19:18:35 +0100
changeset 50021 d96a3f468203
parent 50020 6b9611abcd4c
child 50022 286dfcab9833
add support for function application to measurability prover
src/HOL/Probability/Borel_Space.thy
src/HOL/Probability/Finite_Product_Measure.thy
src/HOL/Probability/Independent_Family.thy
src/HOL/Probability/Lebesgue_Integration.thy
src/HOL/Probability/Radon_Nikodym.thy
src/HOL/Probability/Sigma_Algebra.thy
--- a/src/HOL/Probability/Borel_Space.thy	Tue Nov 06 15:15:33 2012 +0100
+++ b/src/HOL/Probability/Borel_Space.thy	Tue Nov 06 19:18:35 2012 +0100
@@ -74,7 +74,7 @@
     using assms[of S] by simp
 qed
 
-lemma borel_measurable_const[measurable (raw)]:
+lemma borel_measurable_const:
   "(\<lambda>x. c) \<in> borel_measurable M"
   by auto
 
@@ -168,7 +168,8 @@
   shows borel_measurable_le[measurable]: "{w \<in> space M. f w \<le> g w} \<in> sets M"
     and borel_measurable_eq[measurable]: "{w \<in> space M. f w = g w} \<in> sets M"
     and borel_measurable_neq: "{w \<in> space M. f w \<noteq> g w} \<in> sets M"
-  unfolding eq_iff not_less[symmetric] by measurable+
+  unfolding eq_iff not_less[symmetric]
+  by measurable
 
 subsection "Borel space equals sigma algebras over intervals"
 
--- a/src/HOL/Probability/Finite_Product_Measure.thy	Tue Nov 06 15:15:33 2012 +0100
+++ b/src/HOL/Probability/Finite_Product_Measure.thy	Tue Nov 06 19:18:35 2012 +0100
@@ -480,7 +480,7 @@
   shows "(PIE j:I. E j) \<in> sets (PIM i:I. M i)"
   using sets_PiM_I[of I I E M] sets_into_space[OF sets] `finite I` sets by auto
 
-lemma measurable_component_singleton[measurable (raw)]:
+lemma measurable_component_singleton:
   assumes "i \<in> I" shows "(\<lambda>x. x i) \<in> measurable (Pi\<^isub>M I M) (M i)"
 proof (unfold measurable_def, intro CollectI conjI ballI)
   fix A assume "A \<in> sets (M i)"
@@ -491,6 +491,18 @@
     using `A \<in> sets (M i)` `i \<in> I` by (auto intro!: sets_PiM_I)
 qed (insert `i \<in> I`, auto simp: space_PiM)
 
+lemma measurable_component_singleton'[measurable_app]:
+  assumes f: "f \<in> measurable N (Pi\<^isub>M I M)"
+  assumes i: "i \<in> I"
+  shows "(\<lambda>x. (f x) i) \<in> measurable N (M i)"
+  using measurable_compose[OF f measurable_component_singleton, OF i] .
+
+lemma measurable_nat_case[measurable (raw)]:
+  assumes [measurable (raw)]: "i = 0 \<Longrightarrow> f \<in> measurable M N"
+    "\<And>j. i = Suc j \<Longrightarrow> (\<lambda>x. g x j) \<in> measurable M N"
+  shows "(\<lambda>x. nat_case (f x) (g x) i) \<in> measurable M N"
+  by (cases i) simp_all
+
 lemma measurable_add_dim[measurable]:
   "(\<lambda>(f, y). f(i := y)) \<in> measurable (Pi\<^isub>M I M \<Otimes>\<^isub>M M i) (Pi\<^isub>M (insert i I) M)"
     (is "?f \<in> measurable ?P ?I")
--- a/src/HOL/Probability/Independent_Family.thy	Tue Nov 06 15:15:33 2012 +0100
+++ b/src/HOL/Probability/Independent_Family.thy	Tue Nov 06 19:18:35 2012 +0100
@@ -1004,9 +1004,6 @@
   "A \<in> sets N \<Longrightarrow> f \<in> measurable M N \<Longrightarrow> (\<lambda>x. indicator A (f x)) \<in> borel_measurable M"
   using measurable_comp[OF _ borel_measurable_indicator, of f M N A] by (auto simp add: comp_def)
 
-lemma measurable_id_prod[measurable (raw)]: "i = j \<Longrightarrow> (\<lambda>x. x) \<in> measurable (M i) (M j)"
-  by simp
-
 lemma (in product_sigma_finite) distr_component:
   "distr (M i) (Pi\<^isub>M {i} M) (\<lambda>x. \<lambda>i\<in>{i}. x) = Pi\<^isub>M {i} M" (is "?D = ?P")
 proof (intro measure_eqI[symmetric])
--- a/src/HOL/Probability/Lebesgue_Integration.thy	Tue Nov 06 15:15:33 2012 +0100
+++ b/src/HOL/Probability/Lebesgue_Integration.thy	Tue Nov 06 19:18:35 2012 +0100
@@ -270,7 +270,7 @@
     have "simple_function M (\<lambda>x. real (f x i))"
     proof (intro simple_function_borel_measurable)
       show "(\<lambda>x. real (f x i)) \<in> borel_measurable M"
-        using u by (auto intro!: measurable_If simp: real_f)
+        using u by (auto simp: real_f)
       have "(\<lambda>x. real (f x i))`space M \<subseteq> real`{..i*2^i}"
         using f_upper[of _ i] by auto
       then show "finite ((\<lambda>x. real (f x i))`space M)"
--- a/src/HOL/Probability/Radon_Nikodym.thy	Tue Nov 06 15:15:33 2012 +0100
+++ b/src/HOL/Probability/Radon_Nikodym.thy	Tue Nov 06 19:18:35 2012 +0100
@@ -922,8 +922,7 @@
   then show "AE x in M. f x = f' x"
     unfolding eventually_ae_filter using h_borel pos
     by (auto simp add: h_null_sets null_sets_density_iff not_less[symmetric]
-                          AE_iff_null_sets[symmetric])
-       blast
+                          AE_iff_null_sets[symmetric]) blast
 qed
 
 lemma (in sigma_finite_measure) density_unique_iff:
@@ -1126,7 +1125,7 @@
   fixes T :: "'a \<Rightarrow> 'b"
   assumes T: "T \<in> measurable M M'" and T': "T' \<in> measurable M' M"
     and inv: "\<forall>x\<in>space M. T' (T x) = x"
-  and ac: "absolutely_continuous (distr M M' T) (distr N M' T)"
+  and ac[simp]: "absolutely_continuous (distr M M' T) (distr N M' T)"
   and N: "sets N = sets M"
   shows "AE x in M. RN_deriv (distr M M' T) (distr N M' T) (T x) = RN_deriv M N x"
 proof (rule RN_deriv_unique)
@@ -1162,7 +1161,7 @@
     qed
   qed
   have "(RN_deriv ?M' ?N') \<circ> T \<in> borel_measurable M"
-    using T ac by measurable simp
+    using T ac by measurable
   then show "(\<lambda>x. RN_deriv ?M' ?N' (T x)) \<in> borel_measurable M"
     by (simp add: comp_def)
   show "AE x in M. 0 \<le> RN_deriv ?M' ?N' (T x)" using M'.RN_deriv_nonneg[OF ac] by auto
--- a/src/HOL/Probability/Sigma_Algebra.thy	Tue Nov 06 15:15:33 2012 +0100
+++ b/src/HOL/Probability/Sigma_Algebra.thy	Tue Nov 06 19:18:35 2012 +0100
@@ -1331,8 +1331,10 @@
 lemma measurable_ident: "id \<in> measurable M M"
   by (auto simp add: measurable_def)
 
-lemma measurable_ident': "(\<lambda>x. x) \<in> measurable M M"
-  by (auto simp add: measurable_def)
+lemma measurable_ident_sets:
+  assumes eq: "sets M = sets M'" shows "(\<lambda>x. x) \<in> measurable M M'"
+  using measurable_ident[of M]
+  unfolding id_def measurable_def eq sets_eq_imp_space_eq[OF eq] .
 
 lemma sets_Least:
   assumes meas: "\<And>i::nat. {x\<in>space M. P i x} \<in> M"
@@ -1497,28 +1499,58 @@
 
 structure Data = Generic_Data
 (
-  type T = (thm list * thm list) * thm list;
-  val empty = (([], []), []);
+  type T = {
+    concrete_thms : thm Item_Net.T,
+    generic_thms : thm Item_Net.T,
+    dest_thms : thm Item_Net.T,
+    app_thms : thm Item_Net.T }
+  val empty = {
+    concrete_thms = Thm.full_rules,
+    generic_thms = Thm.full_rules,
+    dest_thms = Thm.full_rules,
+    app_thms = Thm.full_rules};
   val extend = I;
-  val merge = fn (((c1, g1), d1), ((c2, g2), d2)) => ((c1 @ c2, g1 @ g2), d1 @ d2);
+  fun merge (t1, t2) = {
+    concrete_thms = Item_Net.merge (#concrete_thms t1, #concrete_thms t2),
+    generic_thms = Item_Net.merge (#generic_thms t1, #generic_thms t2),
+    dest_thms = Item_Net.merge (#dest_thms t1, #dest_thms t2),
+    app_thms = Item_Net.merge (#app_thms t1, #app_thms t2) };
 );
 
 val debug =
   Attrib.setup_config_bool @{binding measurable_debug} (K false)
 
 val backtrack =
-  Attrib.setup_config_int @{binding measurable_backtrack} (K 40)
+  Attrib.setup_config_int @{binding measurable_backtrack} (K 20)
+
+val split =
+  Attrib.setup_config_bool @{binding measurable_split} (K true)
 
-fun get lv = (case lv of Concrete => fst | Generic => snd) o fst o Data.get o Context.Proof; 
+fun TAKE n tac = Seq.take n o tac
+
+fun get lv =
+  rev o Item_Net.content o (case lv of Concrete => #concrete_thms | Generic => #generic_thms) o
+  Data.get o Context.Proof;
+
 fun get_all ctxt = get Concrete ctxt @ get Generic ctxt;
 
-fun update f lv = Data.map (apfst (case lv of Concrete => apfst f | Generic => apsnd f));
-fun add thms' = update (fn thms => thms @ thms');
+fun map_data f1 f2 f3 f4
+  {generic_thms = t1,    concrete_thms = t2,    dest_thms = t3,    app_thms = t4} =
+  {generic_thms = f1 t1, concrete_thms = f2 t2, dest_thms = f3 t3, app_thms = f4 t4 }
+
+fun map_concrete_thms f = map_data f I I I
+fun map_generic_thms f = map_data I f I I
+fun map_dest_thms f = map_data I I f I
+fun map_app_thms f = map_data I I I f
 
-val get_dest = snd o Data.get;
-fun add_dest thm = Data.map (apsnd (fn thms => thms @ [thm]));
+fun update f lv = Data.map (case lv of Concrete => map_concrete_thms f | Generic => map_generic_thms f);
+fun add thms' = update (fold Item_Net.update thms');
 
-fun TRYALL' tacs = fold_rev (curry op APPEND') tacs (K no_tac);
+val get_dest = Item_Net.content o #dest_thms o Data.get;
+val add_dest = Data.map o map_dest_thms o Item_Net.update;
+
+val get_app = Item_Net.content o #app_thms o Data.get;
+val add_app = Data.map o map_app_thms o Item_Net.update;
 
 fun is_too_generic thm =
   let 
@@ -1531,9 +1563,7 @@
 
 fun add_thm (raw, lv) thm ctxt = add (if raw then [thm] else import_theorem ctxt thm) lv ctxt;
 
-fun debug_tac ctxt msg f = if Config.get ctxt debug then K (print_tac (msg ())) THEN' f else f
-
-fun TAKE n f thm = Seq.take n (f thm)
+fun debug_tac ctxt msg f = if Config.get ctxt debug then print_tac (msg ()) THEN f else f
 
 fun nth_hol_goal thm i =
   HOLogic.dest_Trueprop (Logic.strip_imp_concl (strip_all_body (nth (prems_of thm) (i - 1))))
@@ -1543,6 +1573,13 @@
     (Const (@{const_name "Set.member"}, _) $ f $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => f
   | _ => raise (TERM ("not a measurability predicate", [t])))
 
+fun is_cond_formula n thm = if length (prems_of thm) < n then false else
+  (case nth_hol_goal thm n of
+    (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "sets"}, _) $ _)) => false
+  | (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => false
+  | _ => true)
+  handle TERM _ => true;
+
 fun indep (Bound i) t b = i < b orelse t <= i
   | indep (f $ t) top bot = indep f top bot andalso indep t top bot
   | indep (Abs (_,_,t)) top bot = indep t (top + 1) (bot + 1)
@@ -1557,51 +1594,89 @@
           in
             map (fn (f', t) => (f' $ g, t)) (cnt_walk f Ts) @
             map (fn (g', t) => (f $ g', t)) (cnt_walk g Ts) @
-            (if is_countable (fastype_of1 (Ts, g)) andalso loose_bvar1 (g, n)
+            (if is_countable (type_of1 (Ts, g)) andalso loose_bvar1 (g, n)
                 andalso indep g n 0 andalso g <> Bound n
               then [(f $ Bound (n + 1), incr_boundvars (~ n) g)]
               else [])
           end
         | cnt_walk _ _ = []
     in map (fn (t1, t2) => let
-        val T1 = fastype_of1 ([T], t2)
-        val T2 = fastype_of1 ([T], t)
+        val T1 = type_of1 ([T], t2)
+        val T2 = type_of1 ([T], t)
       in ([SOME (Abs (n, T1, Abs (n, T, t1))), NONE, NONE, SOME (Abs (n, T, t2))],
         [SOME T1, SOME T, SOME T2])
       end) (cnt_walk t [T])
     end
   | cnt_prefixes _ _ = []
 
-val split_fun_tac =
+val split_countable_tac =
   Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
     let
       val f = dest_measurable_fun (HOLogic.dest_Trueprop t)
       fun cert f = map (Option.map (f (Proof_Context.theory_of ctxt)))
       fun inst t (ts, Ts) = Drule.instantiate' (cert ctyp_of Ts) (cert cterm_of ts) t
       val cps = cnt_prefixes ctxt f |> map (inst @{thm measurable_compose_countable})
-    in if null cps then no_tac else debug_tac ctxt (K "split fun") (resolve_tac cps) i end
+    in if null cps then no_tac else debug_tac ctxt (K ("split countable fun")) (resolve_tac cps i) end
     handle TERM _ => no_tac) 1)
 
-fun single_measurable_tac ctxt facts =
-  debug_tac ctxt (fn () => "single + " ^
-    Pretty.str_of (Pretty.block (Pretty.commas (map (Syntax.pretty_term ctxt o prop_of) (maps (import_theorem (Context.Proof ctxt)) facts)))))
-  (resolve_tac ((maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf) facts) @ get_all ctxt)
-    APPEND' (split_fun_tac ctxt));
+fun measurable_tac' ctxt ss facts = let
+
+    val imported_thms =
+      (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf) facts) @ get_all ctxt
+
+    fun debug_facts msg () =
+      msg ^ " + " ^ Pretty.str_of (Pretty.list "[" "]"
+        (map (Syntax.pretty_term ctxt o prop_of) (maps (import_theorem (Context.Proof ctxt)) facts)));
+
+    val splitter = if Config.get ctxt split then split_countable_tac ctxt else K no_tac
+
+    val split_app_tac =
+      Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
+        let
+          fun app_prefixes (Abs (n, T, (f $ g))) = let
+                val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
+              in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
+            | app_prefixes _ = []
 
-fun is_cond_formlua n thm = if length (prems_of thm) < n then false else
-  (case nth_hol_goal thm n of
-    (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "sets"}, _) $ _)) => false
-  | (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => false
-  | _ => true)
-  handle TERM _ => true;
+          fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
+            | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
+          val thy = Proof_Context.theory_of ctxt
+          val tunify = Sign.typ_unify thy
+          val thms = map
+              (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
+              (get_app (Context.Proof ctxt))
+          fun cert f = map (fn (t, t') => (f thy t, f thy t'))
+          fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
+            let
+              val inst =
+                (Vartab.empty, ~1)
+                |> tunify (T, thmT)
+                |> tunify (Tf, thmTf)
+                |> tunify (Tc, thmTc)
+                |> Vartab.dest o fst
+              val subst = subst_TVars (map (apsnd snd) inst)
+            in
+              Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
+                cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
+            end
+          val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
+        in if null cps then no_tac
+            else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
+              ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
+        handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
+        handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
 
-fun measurable_tac' ctxt ss facts n =
-  TAKE (Config.get ctxt backtrack)
-  ((single_measurable_tac ctxt facts THEN'
-   REPEAT o (single_measurable_tac ctxt facts APPEND'
-             SOLVED' (fn n => COND (is_cond_formlua n) (debug_tac ctxt (K "simp") (asm_full_simp_tac ss) n) no_tac))) n);
+    val depth_measurable_tac = REPEAT
+      (COND (is_cond_formula 1)
+        (debug_tac ctxt (K "simp") (SOLVED' (asm_full_simp_tac ss) 1))
+        ((debug_tac ctxt (K "single") (resolve_tac imported_thms 1)) APPEND
+          (split_app_tac ctxt 1) APPEND
+          (splitter 1)))
 
-fun measurable_tac ctxt = measurable_tac' ctxt (simpset_of ctxt);
+  in debug_tac ctxt (debug_facts "start") depth_measurable_tac end;
+
+fun measurable_tac ctxt facts =
+  TAKE (Config.get ctxt backtrack) (measurable_tac' ctxt (simpset_of ctxt) facts);
 
 val attr_add = Thm.declaration_attribute o add_thm;
 
@@ -1612,14 +1687,17 @@
 val dest_attr : attribute context_parser =
   Scan.lift (Scan.succeed (Thm.declaration_attribute add_dest));
 
+val app_attr : attribute context_parser =
+  Scan.lift (Scan.succeed (Thm.declaration_attribute add_app));
+
 val method : (Proof.context -> Method.method) context_parser =
-  Scan.lift (Scan.succeed (fn ctxt => METHOD (fn facts => measurable_tac ctxt facts 1)));
+  Scan.lift (Scan.succeed (fn ctxt => METHOD (fn facts => measurable_tac ctxt facts)));
 
 fun simproc ss redex = let
     val ctxt = Simplifier.the_context ss;
     val t = HOLogic.mk_Trueprop (term_of redex);
     fun tac {context = ctxt, ...} =
-      SOLVE (measurable_tac' ctxt ss (Simplifier.prems_of ss) 1);
+      SOLVE (measurable_tac' ctxt ss (Simplifier.prems_of ss));
   in try (fn () => Goal.prove ctxt [] [] t tac RS @{thm Eq_TrueI}) () end;
 
 end
@@ -1628,6 +1706,7 @@
 
 attribute_setup measurable = {* Measurable.attr *} "declaration of measurability theorems"
 attribute_setup measurable_dest = {* Measurable.dest_attr *} "add dest rule for measurability prover"
+attribute_setup measurable_app = {* Measurable.app_attr *} "add application rule for measurability prover"
 method_setup measurable = {* Measurable.method *} "measurability prover"
 simproc_setup measurable ("A \<in> sets M" | "f \<in> measurable M N") = {* K Measurable.simproc *}
 
@@ -1646,8 +1725,7 @@
 declare
   measurable_count_space[measurable (raw)]
   measurable_ident[measurable (raw)]
-  measurable_ident'[measurable (raw)]
-  measurable_count_space_const[measurable (raw)]
+  measurable_ident_sets[measurable (raw)]
   measurable_const[measurable (raw)]
   measurable_If[measurable (raw)]
   measurable_comp[measurable (raw)]
@@ -1686,6 +1764,7 @@
   "pred M (\<lambda>x. Q x) \<Longrightarrow> pred M (\<lambda>x. P x) \<Longrightarrow> pred M (\<lambda>x. Q x = P x)"
   "pred M (\<lambda>x. f x \<in> UNIV)"
   "pred M (\<lambda>x. f x \<in> {})"
+  "pred M (\<lambda>x. P' (f x)) \<Longrightarrow> pred M (\<lambda>x. f x \<in> {x. P' x})"
   "pred M (\<lambda>x. f x \<in> (B x)) \<Longrightarrow> pred M (\<lambda>x. f x \<in> - (B x))"
   "pred M (\<lambda>x. f x \<in> (A x)) \<Longrightarrow> pred M (\<lambda>x. f x \<in> (B x)) \<Longrightarrow> pred M (\<lambda>x. f x \<in> (A x) - (B x))"
   "pred M (\<lambda>x. f x \<in> (A x)) \<Longrightarrow> pred M (\<lambda>x. f x \<in> (B x)) \<Longrightarrow> pred M (\<lambda>x. f x \<in> (A x) \<inter> (B x))"
@@ -1765,7 +1844,8 @@
   Int[measurable (raw)]
 
 lemma pred_in_If[measurable (raw)]:
-  "pred M (\<lambda>x. (P x \<longrightarrow> x \<in> A x) \<and> (\<not> P x \<longrightarrow> x \<in> B x)) \<Longrightarrow> pred M (\<lambda>x. x \<in> (if P x then A x else B x))"
+  "(P \<Longrightarrow> pred M (\<lambda>x. x \<in> A x)) \<Longrightarrow> (\<not> P \<Longrightarrow> pred M (\<lambda>x. x \<in> B x)) \<Longrightarrow>
+    pred M (\<lambda>x. x \<in> (if P then A x else B x))"
   by auto
 
 lemma sets_range[measurable_dest]: