src/HOL/Probability/measurable.ML
changeset 50387 3d8863c41fe8
child 51717 9e7d1c139569
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Probability/measurable.ML	Wed Dec 05 15:59:08 2012 +0100
     1.3 @@ -0,0 +1,238 @@
     1.4 +(*  Title:      HOL/Probability/measurable.ML
     1.5 +    Author:     Johannes Hölzl <hoelzl@in.tum.de>
     1.6 +
     1.7 +Measurability prover.
     1.8 +*)
     1.9 +
    1.10 +signature MEASURABLE = 
    1.11 +sig
    1.12 +  datatype level = Concrete | Generic
    1.13 +
    1.14 +  val simproc : simpset -> cterm -> thm option
    1.15 +  val method : (Proof.context -> Method.method) context_parser
    1.16 +  val measurable_tac : Proof.context -> thm list -> tactic
    1.17 +
    1.18 +  val attr : attribute context_parser
    1.19 +  val dest_attr : attribute context_parser
    1.20 +  val app_attr : attribute context_parser
    1.21 +
    1.22 +  val get : level -> Proof.context -> thm list
    1.23 +  val get_all : Proof.context -> thm list
    1.24 +
    1.25 +  val update : (thm Item_Net.T -> thm Item_Net.T) -> level -> Context.generic -> Context.generic
    1.26 +
    1.27 +end ;
    1.28 +
    1.29 +structure Measurable : MEASURABLE =
    1.30 +struct
    1.31 +
    1.32 +datatype level = Concrete | Generic;
    1.33 +
    1.34 +structure Data = Generic_Data
    1.35 +(
    1.36 +  type T = {
    1.37 +    concrete_thms : thm Item_Net.T,
    1.38 +    generic_thms : thm Item_Net.T,
    1.39 +    dest_thms : thm Item_Net.T,
    1.40 +    app_thms : thm Item_Net.T }
    1.41 +  val empty = {
    1.42 +    concrete_thms = Thm.full_rules,
    1.43 +    generic_thms = Thm.full_rules,
    1.44 +    dest_thms = Thm.full_rules,
    1.45 +    app_thms = Thm.full_rules};
    1.46 +  val extend = I;
    1.47 +  fun merge ({concrete_thms = ct1, generic_thms = gt1, dest_thms = dt1, app_thms = at1 },
    1.48 +      {concrete_thms = ct2, generic_thms = gt2, dest_thms = dt2, app_thms = at2 }) = {
    1.49 +    concrete_thms = Item_Net.merge (ct1, ct2),
    1.50 +    generic_thms = Item_Net.merge (gt1, gt2),
    1.51 +    dest_thms = Item_Net.merge (dt1, dt2),
    1.52 +    app_thms = Item_Net.merge (at1, at2) };
    1.53 +);
    1.54 +
    1.55 +val debug =
    1.56 +  Attrib.setup_config_bool @{binding measurable_debug} (K false)
    1.57 +
    1.58 +val backtrack =
    1.59 +  Attrib.setup_config_int @{binding measurable_backtrack} (K 20)
    1.60 +
    1.61 +val split =
    1.62 +  Attrib.setup_config_bool @{binding measurable_split} (K true)
    1.63 +
    1.64 +fun TAKE n tac = Seq.take n o tac
    1.65 +
    1.66 +fun get lv =
    1.67 +  rev o Item_Net.content o (case lv of Concrete => #concrete_thms | Generic => #generic_thms) o
    1.68 +  Data.get o Context.Proof;
    1.69 +
    1.70 +fun get_all ctxt = get Concrete ctxt @ get Generic ctxt;
    1.71 +
    1.72 +fun map_data f1 f2 f3 f4
    1.73 +  {generic_thms = t1,    concrete_thms = t2,    dest_thms = t3,    app_thms = t4} =
    1.74 +  {generic_thms = f1 t1, concrete_thms = f2 t2, dest_thms = f3 t3, app_thms = f4 t4 }
    1.75 +
    1.76 +fun map_concrete_thms f = map_data f I I I
    1.77 +fun map_generic_thms f = map_data I f I I
    1.78 +fun map_dest_thms f = map_data I I f I
    1.79 +fun map_app_thms f = map_data I I I f
    1.80 +
    1.81 +fun update f lv = Data.map (case lv of Concrete => map_concrete_thms f | Generic => map_generic_thms f);
    1.82 +fun add thms' = update (fold Item_Net.update thms');
    1.83 +
    1.84 +val get_dest = Item_Net.content o #dest_thms o Data.get;
    1.85 +val add_dest = Data.map o map_dest_thms o Item_Net.update;
    1.86 +
    1.87 +val get_app = Item_Net.content o #app_thms o Data.get;
    1.88 +val add_app = Data.map o map_app_thms o Item_Net.update;
    1.89 +
    1.90 +fun is_too_generic thm =
    1.91 +  let 
    1.92 +    val concl = concl_of thm
    1.93 +    val concl' = HOLogic.dest_Trueprop concl handle TERM _ => concl
    1.94 +  in is_Var (head_of concl') end
    1.95 +
    1.96 +fun import_theorem ctxt thm = if is_too_generic thm then [] else
    1.97 +  [thm] @ map_filter (try (fn th' => thm RS th')) (get_dest ctxt);
    1.98 +
    1.99 +fun add_thm (raw, lv) thm ctxt = add (if raw then [thm] else import_theorem ctxt thm) lv ctxt;
   1.100 +
   1.101 +fun debug_tac ctxt msg f = if Config.get ctxt debug then print_tac (msg ()) THEN f else f
   1.102 +
   1.103 +fun nth_hol_goal thm i =
   1.104 +  HOLogic.dest_Trueprop (Logic.strip_imp_concl (strip_all_body (nth (prems_of thm) (i - 1))))
   1.105 +
   1.106 +fun dest_measurable_fun t =
   1.107 +  (case t of
   1.108 +    (Const (@{const_name "Set.member"}, _) $ f $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => f
   1.109 +  | _ => raise (TERM ("not a measurability predicate", [t])))
   1.110 +
   1.111 +fun is_cond_formula n thm = if length (prems_of thm) < n then false else
   1.112 +  (case nth_hol_goal thm n of
   1.113 +    (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "sets"}, _) $ _)) => false
   1.114 +  | (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => false
   1.115 +  | _ => true)
   1.116 +  handle TERM _ => true;
   1.117 +
   1.118 +fun indep (Bound i) t b = i < b orelse t <= i
   1.119 +  | indep (f $ t) top bot = indep f top bot andalso indep t top bot
   1.120 +  | indep (Abs (_,_,t)) top bot = indep t (top + 1) (bot + 1)
   1.121 +  | indep _ _ _ = true;
   1.122 +
   1.123 +fun cnt_prefixes ctxt (Abs (n, T, t)) = let
   1.124 +      fun is_countable t = Type.of_sort (Proof_Context.tsig_of ctxt) (t, @{sort countable})
   1.125 +      fun cnt_walk (Abs (ns, T, t)) Ts =
   1.126 +          map (fn (t', t'') => (Abs (ns, T, t'), t'')) (cnt_walk t (T::Ts))
   1.127 +        | cnt_walk (f $ g) Ts = let
   1.128 +            val n = length Ts - 1
   1.129 +          in
   1.130 +            map (fn (f', t) => (f' $ g, t)) (cnt_walk f Ts) @
   1.131 +            map (fn (g', t) => (f $ g', t)) (cnt_walk g Ts) @
   1.132 +            (if is_countable (type_of1 (Ts, g)) andalso loose_bvar1 (g, n)
   1.133 +                andalso indep g n 0 andalso g <> Bound n
   1.134 +              then [(f $ Bound (n + 1), incr_boundvars (~ n) g)]
   1.135 +              else [])
   1.136 +          end
   1.137 +        | cnt_walk _ _ = []
   1.138 +    in map (fn (t1, t2) => let
   1.139 +        val T1 = type_of1 ([T], t2)
   1.140 +        val T2 = type_of1 ([T], t)
   1.141 +      in ([SOME (Abs (n, T1, Abs (n, T, t1))), NONE, NONE, SOME (Abs (n, T, t2))],
   1.142 +        [SOME T1, SOME T, SOME T2])
   1.143 +      end) (cnt_walk t [T])
   1.144 +    end
   1.145 +  | cnt_prefixes _ _ = []
   1.146 +
   1.147 +val split_countable_tac =
   1.148 +  Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
   1.149 +    let
   1.150 +      val f = dest_measurable_fun (HOLogic.dest_Trueprop t)
   1.151 +      fun cert f = map (Option.map (f (Proof_Context.theory_of ctxt)))
   1.152 +      fun inst t (ts, Ts) = Drule.instantiate' (cert ctyp_of Ts) (cert cterm_of ts) t
   1.153 +      val cps = cnt_prefixes ctxt f |> map (inst @{thm measurable_compose_countable})
   1.154 +    in if null cps then no_tac else debug_tac ctxt (K "split countable fun") (resolve_tac cps i) end
   1.155 +    handle TERM _ => no_tac) 1)
   1.156 +
   1.157 +fun measurable_tac' ctxt ss facts = let
   1.158 +
   1.159 +    val imported_thms =
   1.160 +      (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf) facts) @ get_all ctxt
   1.161 +
   1.162 +    fun debug_facts msg () =
   1.163 +      msg ^ " + " ^ Pretty.str_of (Pretty.list "[" "]"
   1.164 +        (map (Syntax.pretty_term ctxt o prop_of) (maps (import_theorem (Context.Proof ctxt)) facts)));
   1.165 +
   1.166 +    val splitter = if Config.get ctxt split then split_countable_tac ctxt else K no_tac
   1.167 +
   1.168 +    val split_app_tac =
   1.169 +      Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
   1.170 +        let
   1.171 +          fun app_prefixes (Abs (n, T, (f $ g))) = let
   1.172 +                val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
   1.173 +              in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
   1.174 +            | app_prefixes _ = []
   1.175 +
   1.176 +          fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
   1.177 +            | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
   1.178 +          val thy = Proof_Context.theory_of ctxt
   1.179 +          val tunify = Sign.typ_unify thy
   1.180 +          val thms = map
   1.181 +              (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
   1.182 +              (get_app (Context.Proof ctxt))
   1.183 +          fun cert f = map (fn (t, t') => (f thy t, f thy t'))
   1.184 +          fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
   1.185 +            let
   1.186 +              val inst =
   1.187 +                (Vartab.empty, ~1)
   1.188 +                |> tunify (T, thmT)
   1.189 +                |> tunify (Tf, thmTf)
   1.190 +                |> tunify (Tc, thmTc)
   1.191 +                |> Vartab.dest o fst
   1.192 +              val subst = subst_TVars (map (apsnd snd) inst)
   1.193 +            in
   1.194 +              Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
   1.195 +                cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
   1.196 +            end
   1.197 +          val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
   1.198 +        in if null cps then no_tac
   1.199 +            else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
   1.200 +              ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
   1.201 +        handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
   1.202 +        handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
   1.203 +
   1.204 +    fun REPEAT_cnt f n st = ((f n THEN REPEAT_cnt f (n + 1)) ORELSE all_tac) st
   1.205 +
   1.206 +    val depth_measurable_tac = REPEAT_cnt (fn n =>
   1.207 +       (COND (is_cond_formula 1)
   1.208 +        (debug_tac ctxt (K ("simp " ^ string_of_int n)) (SOLVED' (asm_full_simp_tac ss) 1))
   1.209 +        ((debug_tac ctxt (K ("single " ^ string_of_int n)) (resolve_tac imported_thms 1)) APPEND
   1.210 +          (split_app_tac ctxt 1) APPEND
   1.211 +          (splitter 1)))) 0
   1.212 +
   1.213 +  in debug_tac ctxt (debug_facts "start") depth_measurable_tac end;
   1.214 +
   1.215 +fun measurable_tac ctxt facts =
   1.216 +  TAKE (Config.get ctxt backtrack) (measurable_tac' ctxt (simpset_of ctxt) facts);
   1.217 +
   1.218 +val attr_add = Thm.declaration_attribute o add_thm;
   1.219 +
   1.220 +val attr : attribute context_parser =
   1.221 +  Scan.lift (Scan.optional (Args.parens (Scan.optional (Args.$$$ "raw" >> K true) false --
   1.222 +     Scan.optional (Args.$$$ "generic" >> K Generic) Concrete)) (false, Concrete) >> attr_add);
   1.223 +
   1.224 +val dest_attr : attribute context_parser =
   1.225 +  Scan.lift (Scan.succeed (Thm.declaration_attribute add_dest));
   1.226 +
   1.227 +val app_attr : attribute context_parser =
   1.228 +  Scan.lift (Scan.succeed (Thm.declaration_attribute add_app));
   1.229 +
   1.230 +val method : (Proof.context -> Method.method) context_parser =
   1.231 +  Scan.lift (Scan.succeed (fn ctxt => METHOD (fn facts => measurable_tac ctxt facts)));
   1.232 +
   1.233 +fun simproc ss redex = let
   1.234 +    val ctxt = Simplifier.the_context ss;
   1.235 +    val t = HOLogic.mk_Trueprop (term_of redex);
   1.236 +    fun tac {context = ctxt, prems = _ } =
   1.237 +      SOLVE (measurable_tac' ctxt ss (Simplifier.prems_of ss));
   1.238 +  in try (fn () => Goal.prove ctxt [] [] t tac RS @{thm Eq_TrueI}) () end;
   1.239 +
   1.240 +end
   1.241 +