src/HOL/Probability/measurable.ML
changeset 50387 3d8863c41fe8
child 51717 9e7d1c139569
equal deleted inserted replaced
50386:d00e2b0ca069 50387:3d8863c41fe8
       
     1 (*  Title:      HOL/Probability/measurable.ML
       
     2     Author:     Johannes Hölzl <hoelzl@in.tum.de>
       
     3 
       
     4 Measurability prover.
       
     5 *)
       
     6 
       
     7 signature MEASURABLE = 
       
     8 sig
       
     9   datatype level = Concrete | Generic
       
    10 
       
    11   val simproc : simpset -> cterm -> thm option
       
    12   val method : (Proof.context -> Method.method) context_parser
       
    13   val measurable_tac : Proof.context -> thm list -> tactic
       
    14 
       
    15   val attr : attribute context_parser
       
    16   val dest_attr : attribute context_parser
       
    17   val app_attr : attribute context_parser
       
    18 
       
    19   val get : level -> Proof.context -> thm list
       
    20   val get_all : Proof.context -> thm list
       
    21 
       
    22   val update : (thm Item_Net.T -> thm Item_Net.T) -> level -> Context.generic -> Context.generic
       
    23 
       
    24 end ;
       
    25 
       
    26 structure Measurable : MEASURABLE =
       
    27 struct
       
    28 
       
    29 datatype level = Concrete | Generic;
       
    30 
       
    31 structure Data = Generic_Data
       
    32 (
       
    33   type T = {
       
    34     concrete_thms : thm Item_Net.T,
       
    35     generic_thms : thm Item_Net.T,
       
    36     dest_thms : thm Item_Net.T,
       
    37     app_thms : thm Item_Net.T }
       
    38   val empty = {
       
    39     concrete_thms = Thm.full_rules,
       
    40     generic_thms = Thm.full_rules,
       
    41     dest_thms = Thm.full_rules,
       
    42     app_thms = Thm.full_rules};
       
    43   val extend = I;
       
    44   fun merge ({concrete_thms = ct1, generic_thms = gt1, dest_thms = dt1, app_thms = at1 },
       
    45       {concrete_thms = ct2, generic_thms = gt2, dest_thms = dt2, app_thms = at2 }) = {
       
    46     concrete_thms = Item_Net.merge (ct1, ct2),
       
    47     generic_thms = Item_Net.merge (gt1, gt2),
       
    48     dest_thms = Item_Net.merge (dt1, dt2),
       
    49     app_thms = Item_Net.merge (at1, at2) };
       
    50 );
       
    51 
       
    52 val debug =
       
    53   Attrib.setup_config_bool @{binding measurable_debug} (K false)
       
    54 
       
    55 val backtrack =
       
    56   Attrib.setup_config_int @{binding measurable_backtrack} (K 20)
       
    57 
       
    58 val split =
       
    59   Attrib.setup_config_bool @{binding measurable_split} (K true)
       
    60 
       
    61 fun TAKE n tac = Seq.take n o tac
       
    62 
       
    63 fun get lv =
       
    64   rev o Item_Net.content o (case lv of Concrete => #concrete_thms | Generic => #generic_thms) o
       
    65   Data.get o Context.Proof;
       
    66 
       
    67 fun get_all ctxt = get Concrete ctxt @ get Generic ctxt;
       
    68 
       
    69 fun map_data f1 f2 f3 f4
       
    70   {generic_thms = t1,    concrete_thms = t2,    dest_thms = t3,    app_thms = t4} =
       
    71   {generic_thms = f1 t1, concrete_thms = f2 t2, dest_thms = f3 t3, app_thms = f4 t4 }
       
    72 
       
    73 fun map_concrete_thms f = map_data f I I I
       
    74 fun map_generic_thms f = map_data I f I I
       
    75 fun map_dest_thms f = map_data I I f I
       
    76 fun map_app_thms f = map_data I I I f
       
    77 
       
    78 fun update f lv = Data.map (case lv of Concrete => map_concrete_thms f | Generic => map_generic_thms f);
       
    79 fun add thms' = update (fold Item_Net.update thms');
       
    80 
       
    81 val get_dest = Item_Net.content o #dest_thms o Data.get;
       
    82 val add_dest = Data.map o map_dest_thms o Item_Net.update;
       
    83 
       
    84 val get_app = Item_Net.content o #app_thms o Data.get;
       
    85 val add_app = Data.map o map_app_thms o Item_Net.update;
       
    86 
       
    87 fun is_too_generic thm =
       
    88   let 
       
    89     val concl = concl_of thm
       
    90     val concl' = HOLogic.dest_Trueprop concl handle TERM _ => concl
       
    91   in is_Var (head_of concl') end
       
    92 
       
    93 fun import_theorem ctxt thm = if is_too_generic thm then [] else
       
    94   [thm] @ map_filter (try (fn th' => thm RS th')) (get_dest ctxt);
       
    95 
       
    96 fun add_thm (raw, lv) thm ctxt = add (if raw then [thm] else import_theorem ctxt thm) lv ctxt;
       
    97 
       
    98 fun debug_tac ctxt msg f = if Config.get ctxt debug then print_tac (msg ()) THEN f else f
       
    99 
       
   100 fun nth_hol_goal thm i =
       
   101   HOLogic.dest_Trueprop (Logic.strip_imp_concl (strip_all_body (nth (prems_of thm) (i - 1))))
       
   102 
       
   103 fun dest_measurable_fun t =
       
   104   (case t of
       
   105     (Const (@{const_name "Set.member"}, _) $ f $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => f
       
   106   | _ => raise (TERM ("not a measurability predicate", [t])))
       
   107 
       
   108 fun is_cond_formula n thm = if length (prems_of thm) < n then false else
       
   109   (case nth_hol_goal thm n of
       
   110     (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "sets"}, _) $ _)) => false
       
   111   | (Const (@{const_name "Set.member"}, _) $ _ $ (Const (@{const_name "measurable"}, _) $ _ $ _)) => false
       
   112   | _ => true)
       
   113   handle TERM _ => true;
       
   114 
       
   115 fun indep (Bound i) t b = i < b orelse t <= i
       
   116   | indep (f $ t) top bot = indep f top bot andalso indep t top bot
       
   117   | indep (Abs (_,_,t)) top bot = indep t (top + 1) (bot + 1)
       
   118   | indep _ _ _ = true;
       
   119 
       
   120 fun cnt_prefixes ctxt (Abs (n, T, t)) = let
       
   121       fun is_countable t = Type.of_sort (Proof_Context.tsig_of ctxt) (t, @{sort countable})
       
   122       fun cnt_walk (Abs (ns, T, t)) Ts =
       
   123           map (fn (t', t'') => (Abs (ns, T, t'), t'')) (cnt_walk t (T::Ts))
       
   124         | cnt_walk (f $ g) Ts = let
       
   125             val n = length Ts - 1
       
   126           in
       
   127             map (fn (f', t) => (f' $ g, t)) (cnt_walk f Ts) @
       
   128             map (fn (g', t) => (f $ g', t)) (cnt_walk g Ts) @
       
   129             (if is_countable (type_of1 (Ts, g)) andalso loose_bvar1 (g, n)
       
   130                 andalso indep g n 0 andalso g <> Bound n
       
   131               then [(f $ Bound (n + 1), incr_boundvars (~ n) g)]
       
   132               else [])
       
   133           end
       
   134         | cnt_walk _ _ = []
       
   135     in map (fn (t1, t2) => let
       
   136         val T1 = type_of1 ([T], t2)
       
   137         val T2 = type_of1 ([T], t)
       
   138       in ([SOME (Abs (n, T1, Abs (n, T, t1))), NONE, NONE, SOME (Abs (n, T, t2))],
       
   139         [SOME T1, SOME T, SOME T2])
       
   140       end) (cnt_walk t [T])
       
   141     end
       
   142   | cnt_prefixes _ _ = []
       
   143 
       
   144 val split_countable_tac =
       
   145   Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
       
   146     let
       
   147       val f = dest_measurable_fun (HOLogic.dest_Trueprop t)
       
   148       fun cert f = map (Option.map (f (Proof_Context.theory_of ctxt)))
       
   149       fun inst t (ts, Ts) = Drule.instantiate' (cert ctyp_of Ts) (cert cterm_of ts) t
       
   150       val cps = cnt_prefixes ctxt f |> map (inst @{thm measurable_compose_countable})
       
   151     in if null cps then no_tac else debug_tac ctxt (K "split countable fun") (resolve_tac cps i) end
       
   152     handle TERM _ => no_tac) 1)
       
   153 
       
   154 fun measurable_tac' ctxt ss facts = let
       
   155 
       
   156     val imported_thms =
       
   157       (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf) facts) @ get_all ctxt
       
   158 
       
   159     fun debug_facts msg () =
       
   160       msg ^ " + " ^ Pretty.str_of (Pretty.list "[" "]"
       
   161         (map (Syntax.pretty_term ctxt o prop_of) (maps (import_theorem (Context.Proof ctxt)) facts)));
       
   162 
       
   163     val splitter = if Config.get ctxt split then split_countable_tac ctxt else K no_tac
       
   164 
       
   165     val split_app_tac =
       
   166       Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
       
   167         let
       
   168           fun app_prefixes (Abs (n, T, (f $ g))) = let
       
   169                 val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
       
   170               in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
       
   171             | app_prefixes _ = []
       
   172 
       
   173           fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
       
   174             | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
       
   175           val thy = Proof_Context.theory_of ctxt
       
   176           val tunify = Sign.typ_unify thy
       
   177           val thms = map
       
   178               (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
       
   179               (get_app (Context.Proof ctxt))
       
   180           fun cert f = map (fn (t, t') => (f thy t, f thy t'))
       
   181           fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
       
   182             let
       
   183               val inst =
       
   184                 (Vartab.empty, ~1)
       
   185                 |> tunify (T, thmT)
       
   186                 |> tunify (Tf, thmTf)
       
   187                 |> tunify (Tc, thmTc)
       
   188                 |> Vartab.dest o fst
       
   189               val subst = subst_TVars (map (apsnd snd) inst)
       
   190             in
       
   191               Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
       
   192                 cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
       
   193             end
       
   194           val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
       
   195         in if null cps then no_tac
       
   196             else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
       
   197               ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
       
   198         handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
       
   199         handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
       
   200 
       
   201     fun REPEAT_cnt f n st = ((f n THEN REPEAT_cnt f (n + 1)) ORELSE all_tac) st
       
   202 
       
   203     val depth_measurable_tac = REPEAT_cnt (fn n =>
       
   204        (COND (is_cond_formula 1)
       
   205         (debug_tac ctxt (K ("simp " ^ string_of_int n)) (SOLVED' (asm_full_simp_tac ss) 1))
       
   206         ((debug_tac ctxt (K ("single " ^ string_of_int n)) (resolve_tac imported_thms 1)) APPEND
       
   207           (split_app_tac ctxt 1) APPEND
       
   208           (splitter 1)))) 0
       
   209 
       
   210   in debug_tac ctxt (debug_facts "start") depth_measurable_tac end;
       
   211 
       
   212 fun measurable_tac ctxt facts =
       
   213   TAKE (Config.get ctxt backtrack) (measurable_tac' ctxt (simpset_of ctxt) facts);
       
   214 
       
   215 val attr_add = Thm.declaration_attribute o add_thm;
       
   216 
       
   217 val attr : attribute context_parser =
       
   218   Scan.lift (Scan.optional (Args.parens (Scan.optional (Args.$$$ "raw" >> K true) false --
       
   219      Scan.optional (Args.$$$ "generic" >> K Generic) Concrete)) (false, Concrete) >> attr_add);
       
   220 
       
   221 val dest_attr : attribute context_parser =
       
   222   Scan.lift (Scan.succeed (Thm.declaration_attribute add_dest));
       
   223 
       
   224 val app_attr : attribute context_parser =
       
   225   Scan.lift (Scan.succeed (Thm.declaration_attribute add_app));
       
   226 
       
   227 val method : (Proof.context -> Method.method) context_parser =
       
   228   Scan.lift (Scan.succeed (fn ctxt => METHOD (fn facts => measurable_tac ctxt facts)));
       
   229 
       
   230 fun simproc ss redex = let
       
   231     val ctxt = Simplifier.the_context ss;
       
   232     val t = HOLogic.mk_Trueprop (term_of redex);
       
   233     fun tac {context = ctxt, prems = _ } =
       
   234       SOLVE (measurable_tac' ctxt ss (Simplifier.prems_of ss));
       
   235   in try (fn () => Goal.prove ctxt [] [] t tac RS @{thm Eq_TrueI}) () end;
       
   236 
       
   237 end
       
   238