src/HOL/Probability/measurable.ML
author haftmann
Tue Oct 13 09:21:15 2015 +0200 (2015-10-13)
changeset 61424 c3658c18b7bc
parent 60807 d7e6c7760db5
child 61877 276ad4354069
permissions -rw-r--r--
prod_case as canonical name for product type eliminator
     1 (*  Title:      HOL/Probability/measurable.ML
     2     Author:     Johannes Hölzl <hoelzl@in.tum.de>
     3 
     4 Measurability prover.
     5 *)
     6 
     (* Interface of the measurability prover. *)
     signature MEASURABLE =
     sig
       (* Maps one theorem to a list of theorems, threading the proof context. *)
       type preprocessor = thm -> Proof.context -> thm list * Proof.context

       (* Tag stored together with every measurability rule. *)
       datatype level = Concrete | Generic

       (* attributes *)
       val dest_thm_attr : attribute context_parser
       val cong_thm_attr : attribute context_parser
       val measurable_thm_attr : bool * (bool * level) -> attribute

       (* contents of the stored rule nets *)
       val get_all : Context.generic -> thm list
       val get_dest : Context.generic -> thm list
       val get_cong : Context.generic -> thm list

       (* congruence rules; the flag of add_del_cong_thm selects adding (true)
          or deleting (false) *)
       val add_del_cong_thm : bool -> thm -> Context.generic -> Context.generic
       val add_local_cong : thm -> Proof.context -> Proof.context

       (* named preprocessors *)
       val add_preprocessor : string -> preprocessor -> Context.generic -> Context.generic
       val del_preprocessor : string -> Context.generic -> Context.generic

       (* fact preparation, the tactic itself, and its simproc wrapper *)
       val prepare_facts : Proof.context -> thm list -> thm list * Proof.context
       val measurable_tac : Proof.context -> thm list -> tactic
       val simproc : Proof.context -> cterm -> thm option
     end;
    33 
    34 structure Measurable : MEASURABLE =
    35 struct
    36 
    (* A preprocessor turns one theorem into a list of theorems and hands back
       a proof context alongside the result. *)
    type preprocessor = thm -> Proof.context -> thm list * Proof.context

    (* Tag attached to each stored measurability rule; prepare_facts lists the
       Concrete rules before the Generic ones. *)
    datatype level = Concrete | Generic
    40 
    (* Equality on net entries: the attached data must be equal and the two
       theorems must be identified by Thm.eq_thm_prop. *)
    fun eq_measurable_thms (entry1, entry2) =
      let
        val (thm1, data1) = entry1
        val (thm2, data2) = entry2
      in data1 = data2 andalso Thm.eq_thm_prop (thm1, thm2) end;
    43 
    (* Append to xs those entries of ys whose name does not already occur in xs;
       the order within both lists is kept. *)
    fun merge_dups (xs:(string * preprocessor) list) ys =
      let
        fun known name = exists (fn (name', _) => name' = name) xs
      in xs @ filter_out (fn (name, _) => known name) ys end
    46 
    (* Generic context data of the prover: three theorem nets plus the list of
       named preprocessors. *)
    structure Data = Generic_Data
    (
      type T =
       {measurable_thms : (thm * (bool * level)) Item_Net.T,
        dest_thms : thm Item_Net.T,
        cong_thms : thm Item_Net.T,
        preprocessors : (string * preprocessor) list}

      (* Measurability rules are indexed by their proposition and compared with
         eq_measurable_thms; the two other nets start out as Thm.full_rules. *)
      val empty: T =
       {measurable_thms = Item_Net.init eq_measurable_thms (fn (th, _) => [Thm.prop_of th]),
        dest_thms = Thm.full_rules,
        cong_thms = Thm.full_rules,
        preprocessors = []};

      val extend = I;

      (* Field-wise merge: nets via Item_Net.merge, preprocessors via merge_dups. *)
      fun merge (data1: T, data2: T) : T =
       {measurable_thms = Item_Net.merge (#measurable_thms data1, #measurable_thms data2),
        dest_thms = Item_Net.merge (#dest_thms data1, #dest_thms data2),
        cong_thms = Item_Net.merge (#cong_thms data1, #cong_thms data2),
        preprocessors = merge_dups (#preprocessors data1) (#preprocessors data2)};
    );
    68 
    (* Boolean option measurable_debug (default false); consulted by debug_tac
       and by the resolution step of measurable_tac. *)
    val debug = Attrib.setup_config_bool @{binding measurable_debug} (fn _ => false)

    (* Boolean option measurable_split (default true); measurable_tac only uses
       its countable-split step when this is set. *)
    val split = Attrib.setup_config_bool @{binding measurable_split} (fn _ => true)
    74 
    (* Apply one function to each of the four fields of the data record. *)
    fun map_data on_thms on_dests on_congs on_procs
        {measurable_thms, dest_thms, cong_thms, preprocessors} =
      {measurable_thms = on_thms measurable_thms,
       dest_thms = on_dests dest_thms,
       cong_thms = on_congs cong_thms,
       preprocessors = on_procs preprocessors}

    (* Single-field variants: the remaining fields are passed through with I. *)
    fun map_measurable_thms f data = map_data f I I I data
    fun map_dest_thms f data = map_data I f I I data
    fun map_cong_thms f data = map_data I I f I data
    fun map_preprocessors f data = map_data I I I f data
    83 
    (* Attribute parser shared by the dest and cong attributes.  The keyword
       recognised by Args.add, or no keyword at all, selects Item_Net.update;
       the keyword recognised by Args.del selects Item_Net.remove.  The chosen
       operation is applied to the net picked out by map_field. *)
    fun generic_add_del map_field : attribute context_parser =
      let
        fun attribute_of change =
          Thm.declaration_attribute (fn th => Data.map (map_field (change th)))
      in
        Scan.lift
          (Args.add >> K Item_Net.update ||
           Args.del >> K Item_Net.remove ||
           Scan.succeed Item_Net.update)
        >> attribute_of
      end

    (* Add/delete attribute for the dest_thms net. *)
    val dest_thm_attr = generic_add_del map_dest_thms

    (* Add/delete attribute for the cong_thms net. *)
    val cong_thm_attr = generic_add_del map_cong_thms
    92 
    (* Remove from the net every entry whose theorem equals th according to
       Thm.eq_thm, whatever data is attached to it. *)
    fun del_thm th net =
      let
        fun same_thm (th', _) = Thm.eq_thm (th, th')
        val matching = filter same_thm (Item_Net.content net)
      in fold Item_Net.remove matching net end;
    97 
    (* Attribute on the measurable_thms net: with do_add the theorem is stored
       paired with d, otherwise all entries for the theorem are deleted. *)
    fun measurable_thm_attr (do_add, d) =
      let
        fun change th = if do_add then Item_Net.update (th, d) else del_thm th
      in Thm.declaration_attribute (fn th => Data.map (map_measurable_thms (change th))) end
   100 
   (* Content of the dest_thms net. *)
   fun get_dest context = Item_Net.content (#dest_thms (Data.get context));

   (* Content of the cong_thms net, and its update/remove operations. *)
   fun get_cong context = Item_Net.content (#cong_thms (Data.get context));
   fun add_cong th = Data.map (map_cong_thms (Item_Net.update th));
   fun del_cong th = Data.map (map_cong_thms (Item_Net.remove th));

   (* true adds the congruence rule, false deletes it. *)
   fun add_del_cong_thm add = if add then add_cong else del_cong
   108 
   (* Register preprocessor f under name.

      Names act as keys elsewhere in this file: merge_dups keeps one entry per
      name and del_preprocessor removes by name.  Appending unconditionally
      would therefore leave two entries when a name is registered twice, and
      prepare_facts would run both.  An existing entry with the same name is
      replaced in place, keeping its position; a new name is appended at the
      end as before. *)
   fun add_preprocessor name f =
     let
       fun same (n, _) = (n = name)
       fun register procs =
         if exists same procs
         then map (fn entry => if same entry then (name, f) else entry) procs
         else procs @ [(name, f)]
     in Data.map (map_preprocessors register) end
   (* Drop every preprocessor registered under name. *)
   fun del_preprocessor name =
     Data.map (map_preprocessors (filter_out (fn (n, _) => n = name)))

   (* Add a congruence rule to the data of a proof context. *)
   fun add_local_cong th = Context.proof_map (add_cong th)

   (* Preprocessors registered in the given proof context. *)
   fun get_preprocessors ctxt = #preprocessors (Data.get (Context.Proof ctxt));
   114 
   (* True when the head of the theorem's conclusion is a schematic variable.
      A Trueprop wrapper is stripped first when present. *)
   fun is_too_generic thm =
     let
       fun strip_Trueprop t = HOLogic.dest_Trueprop t handle TERM _ => t
     in is_Var (head_of (strip_Trueprop (Thm.concl_of thm))) end
   120 
   (* Stored measurability rules together with their (raw, level) data. *)
   fun get_thms context = Item_Net.content (#measurable_thms (Data.get context));

   (* Stored measurability rules without the attached data. *)
   fun get_all context = map fst (get_thms context);
   124 
   (* Run tactic f; when the debug option is set, first print the goal state
      with the message produced by msg.  The message is only computed in
      debug mode. *)
   fun debug_tac ctxt msg f =
     if not (Config.get ctxt debug) then f
     else print_tac ctxt (msg ()) THEN f
   127 
   (* HOL conclusion of the i-th premise (counted from 1) of thm: strip the
      outer quantifiers and the assumptions, then remove Trueprop. *)
   fun nth_hol_goal thm i =
     Thm.prems_of thm
     |> (fn prems => nth prems (i - 1))
     |> strip_all_body
     |> Logic.strip_imp_concl
     |> HOLogic.dest_Trueprop
   130 
   (* From a term of the form "f : measurable _ _" return f; any other term
      raises TERM. *)
   fun dest_measurable_fun
         (Const (@{const_name "Set.member"}, _) $ f $
           (Const (@{const_name "measurable"}, _) $ _ $ _)) = f
     | dest_measurable_fun t = raise (TERM ("not a measurability predicate", [t]))
   135 
   (* False when thm has fewer than n premises, or when the HOL conclusion of
      premise n is a membership in "sets _" or in "measurable _ _".  True in
      every other case, including when taking the goal apart raises TERM. *)
   fun not_measurable_prop n thm =
     let
       fun is_measurable_goal
             (Const (@{const_name "Set.member"}, _) $ _ $
               (Const (@{const_name "sets"}, _) $ _)) = true
         | is_measurable_goal
             (Const (@{const_name "Set.member"}, _) $ _ $
               (Const (@{const_name "measurable"}, _) $ _ $ _)) = true
         | is_measurable_goal _ = false
     in
       n <= length (Thm.prems_of thm) andalso
         (not (is_measurable_goal (nth_hol_goal thm n)) handle TERM _ => true)
     end;
   144 
   (* True iff every Bound i in the term satisfies i < bot or top <= i.  Both
      limits are raised by one under each abstraction. *)
   fun indep tm top bot =
     (case tm of
       Bound i => i < bot orelse top <= i
     | f $ arg => indep f top bot andalso indep arg top bot
     | Abs (_, _, body) => indep body (top + 1) (bot + 1)
     | _ => true);
   149 
   (* For an abstraction, enumerate argument subterms g of its body that
        - have a type of sort countable,
        - contain the variable of the leading abstraction loose,
        - pass the indep check, and
        - are not that variable itself.
      Each hit yields a pair of optional term and optional type lists;
      measurable_tac feeds these to Thm.instantiate' for the rule
      measurable_compose_countable.  Anything but an abstraction yields []. *)
   fun cnt_prefixes ctxt (Abs (x, xT, body)) =
       let
         fun is_countable ty =
           Sign.of_sort (Proof_Context.theory_of ctxt) (ty, @{sort countable})

         (* Ts holds the types of the binders passed so far, innermost first;
            the result pairs the term with one argument replaced against the
            argument that was taken out. *)
         fun cnt_walk (Abs (y, yT, inner)) Ts =
               map (fn (inner', sub) => (Abs (y, yT, inner'), sub)) (cnt_walk inner (yT :: Ts))
           | cnt_walk (f $ g) Ts =
               let
                 (* de Bruijn index of the leading abstraction's variable here *)
                 val outer = length Ts - 1
                 val in_fun = map (fn (f', sub) => (f' $ g, sub)) (cnt_walk f Ts)
                 val in_arg = map (fn (g', sub) => (f $ g', sub)) (cnt_walk g Ts)
                 val here =
                   if is_countable (type_of1 (Ts, g)) andalso loose_bvar1 (g, outer)
                     andalso indep g outer 0 andalso g <> Bound outer
                   then [(f $ Bound (outer + 1), incr_boundvars (~ outer) g)]
                   else []
               in in_fun @ in_arg @ here end
           | cnt_walk _ _ = []

         fun mk_inst (ctx_body, sub) =
           let
             val subT = type_of1 ([xT], sub)
             val bodyT = type_of1 ([xT], body)
           in
             ([SOME (Abs (x, subT, Abs (x, xT, ctx_body))), NONE, NONE, SOME (Abs (x, xT, sub))],
               [SOME subT, SOME xT, SOME bodyT])
           end
       in map mk_inst (cnt_walk body [xT]) end
     | cnt_prefixes _ _ = []
   174 
   (* Collect thm itself, its resolution with measurable_compose_rev when that
      succeeds, and everything obtained by repeatedly resolving with the rules
      in dests.  A resolution that raises THM contributes nothing. *)
   fun apply_dests thm dests =
     let
       fun step th dest =
         let val th' = th RS dest
         in th' :: closure th' end
         handle THM _ => []
       and closure th = flat (map (step th) dests)

       val composed = [thm RS @{thm measurable_compose_rev}] handle THM _ => []
     in thm :: composed @ closure thm end
   187 
   (* Build the rule list used by measurable_tac.  The given facts are
      normalised with Simplifier.norm_hhf and come first, followed by the
      stored Concrete rules and then the stored Generic ones.  Each rule not
      flagged raw is run through the built-in "std" preprocessor and then the
      registered preprocessors; rules flagged raw are kept unchanged.  Returns
      the resulting theorems and the context threaded through the
      preprocessors. *)
   fun prepare_facts ctxt facts =
     let
       val context = Context.Proof ctxt
       val dests = get_dest context

       (* Built-in preprocessor: drop rules whose conclusion head is schematic,
          expand the others with the dest rules. *)
       fun std_preprocessor thm =
         pair (if is_too_generic thm then [] else apply_dests thm dests)
       val procs = ("std", std_preprocessor) :: get_preprocessors ctxt

       fun run_procs thm = fold_map (fn (_, proc) => proc thm) procs #>> flat
       fun preprocess_thm (thm, raw) = if raw then pair [thm] else run_procs thm

       val given = map (fn th => (Simplifier.norm_hhf ctxt th, false)) facts
       val stored = rev (get_thms context)
       fun at_level lv =
         map_filter (fn (th, (raw, lv')) => if lv = lv' then SOME (th, raw) else NONE) stored

       val pre_thms = given @ at_level Concrete @ at_level Generic
     in fold_map preprocess_thm pre_thms ctxt |>> flat end
   203 
   (* The measurability tactic.  It repeats, on subgoal 1, the first applicable
      step out of:
        1. close a goal that is not a sets/measurable membership by simp,
        2. resolve with a prepared rule,
        3. split with measurable_compose_countable (only when the split option
           is set),
        4. rewrite the set on the right of the membership with congruence rules.
      With the debug option set, the steps print trace messages. *)
   fun measurable_tac ctxt facts =
     let
       (* Delayed trace message: msg followed by the proposition of thm. *)
       fun debug_fact msg thm () =
         msg ^ " " ^ Pretty.str_of (Syntax.pretty_term ctxt (Thm.prop_of thm))

       (* Apply tac at i only when cond holds for i and the goal state. *)
       fun IF' cond tac i = COND (cond i) (tac i) no_tac

       (* Resolution with a list of rules; in debug mode the rules are tried
          one at a time so that the rule that applied can be reported. *)
       fun r_tac msg =
         if Config.get ctxt debug then
           let
             fun traced thm =
               resolve_tac ctxt [thm]
               THEN' K (debug_tac ctxt (debug_fact (msg ^ " resolved using") thm) all_tac)
           in FIRST' o map traced end
         else resolve_tac ctxt

       val elem_congI = @{lemma "A = B \<Longrightarrow> x \<in> B \<Longrightarrow> x \<in> A" by simp}

       (* The definitions above keep using the caller's context; everything
          below uses the context returned by the preprocessors. *)
       val (thms, prep_ctxt) = prepare_facts ctxt facts

       (* Equations between two "sets _" terms or two "measurable _ _" terms. *)
       fun is_sets_eq (Const (@{const_name "HOL.eq"}, _) $ lhs $ rhs) =
             (case (lhs, rhs) of
               (Const (@{const_name "sets"}, _) $ _, Const (@{const_name "sets"}, _) $ _) => true
             | (Const (@{const_name "measurable"}, _) $ _ $ _,
                 Const (@{const_name "measurable"}, _) $ _ $ _) => true
             | _ => false)
         | is_sets_eq _ = false

       fun states_sets_eq thm =
         (Thm.concl_of thm |> HOLogic.dest_Trueprop |> is_sets_eq) handle TERM _ => false

       (* Registered congruence rules plus the given facts of that shape. *)
       val cong_thms = get_cong (Context.Proof prep_ctxt) @ filter states_sets_eq facts

       (* Apply elem_congI and solve the resulting equation with the congruence
          rules, refl, or simp using the focused premises. *)
       fun sets_cong_tac i =
         Subgoal.FOCUS (fn {context = focus_ctxt, prems, ...} =>
           let
             val simp_ctxt = Simplifier.add_prems prems focus_ctxt
             fun has_goal_after j st = Thm.nprems_of st > j
             fun solve_step j =
               (r_tac "cong solve" (cong_thms @ [@{thm refl}])
                 ORELSE' IF' has_goal_after (SOLVED' (asm_full_simp_tac simp_ctxt))) j
           in
             (r_tac "cong intro" [elem_congI]
               THEN' SOLVED' (fn j => REPEAT_DETERM (solve_step j))) 1
           end) prep_ctxt i
         THEN flexflex_tac prep_ctxt

       val simp_solver_tac =
         IF' not_measurable_prop
           (debug_tac prep_ctxt (K "simp ") o SOLVED' (asm_full_simp_tac prep_ctxt))

       (* Resolve with the instances of measurable_compose_countable proposed
          by cnt_prefixes; fails when the goal is not a measurability
          statement. *)
       val split_countable_tac =
         Subgoal.FOCUS (fn {context = focus_ctxt, ...} =>
           let
             fun inst (ts, Ts) =
               Thm.instantiate'
                 (map (Option.map (Thm.ctyp_of focus_ctxt)) Ts)
                 (map (Option.map (Thm.cterm_of focus_ctxt)) ts)
                 @{thm measurable_compose_countable}
             fun split_goal (goal, j) =
               let val f = dest_measurable_fun (HOLogic.dest_Trueprop goal)
               in r_tac "case_prod countable" (map inst (cnt_prefixes focus_ctxt f)) j end
               handle TERM _ => no_tac
           in SUBGOAL split_goal 1 end)

       val splitter =
         if Config.get prep_ctxt split then split_countable_tac prep_ctxt else K no_tac

       val single_step_tac =
         simp_solver_tac
         ORELSE' r_tac "step" thms
         ORELSE' splitter
         ORELSE' (CHANGED o sets_cong_tac)
         ORELSE' (K (debug_tac prep_ctxt (K "backtrack") no_tac))

     in debug_tac prep_ctxt (K "start") (REPEAT (single_step_tac 1)) end;
   271 
   (* Try to prove the redex, wrapped in Trueprop, with measurable_tac using
      the simplifier's premises as facts.  On success the proved theorem is
      resolved with Eq_TrueI and returned; on failure the result is NONE. *)
   fun simproc ctxt redex =
     let
       val goal = HOLogic.mk_Trueprop (Thm.term_of redex);
       fun solve_tac {context = goal_ctxt, prems = _} =
         SOLVE (measurable_tac goal_ctxt (Simplifier.prems_of goal_ctxt));
       fun prove () = Goal.prove ctxt [] [] goal solve_tac RS @{thm Eq_TrueI};
     in try prove () end;
   278 
   279 end
   280