src/HOL/Probability/measurable.ML
changeset 59047 8d7cec9b861d
parent 58965 a62cdcc5344b
child 59048 7dc8ac6f0895
     1.1 --- a/src/HOL/Probability/measurable.ML	Mon Nov 24 12:35:13 2014 +0100
     1.2 +++ b/src/HOL/Probability/measurable.ML	Mon Nov 24 12:20:35 2014 +0100
     1.3 @@ -8,20 +8,16 @@
     1.4  sig
     1.5    datatype level = Concrete | Generic
     1.6  
     1.7 -  val add_app : thm -> Context.generic -> Context.generic
     1.8 -  val add_dest : thm -> Context.generic -> Context.generic
     1.9 -  val add_thm : bool * level -> thm -> Context.generic -> Context.generic
    1.10 -  val del_thm : bool * level -> thm -> Context.generic -> Context.generic
    1.11 -  val add_del_thm : bool -> (bool * level) -> thm -> Context.generic -> Context.generic
    1.12 +  val app_thm_attr : attribute context_parser
    1.13 +  val dest_thm_attr : attribute context_parser
    1.14 +  val measurable_thm_attr : bool * (bool * level) -> attribute
    1.15  
    1.16    val measurable_tac : Proof.context -> thm list -> tactic
    1.17  
    1.18    val simproc : Proof.context -> cterm -> thm option
    1.19  
    1.20 -  val get : level -> Proof.context -> thm list
    1.21 +  val get_thms : Proof.context -> thm list
    1.22    val get_all : Proof.context -> thm list
    1.23 -
    1.24 -  val update : (thm Item_Net.T -> thm Item_Net.T) -> level -> Context.generic -> Context.generic
    1.25  end ;
    1.26  
    1.27  structure Measurable : MEASURABLE =
    1.28 @@ -29,23 +25,23 @@
    1.29  
    1.30  datatype level = Concrete | Generic;
    1.31  
    1.32 +fun eq_measurable_thms ((th1, d1), (th2, d2)) = 
    1.33 +  d1 = d2 andalso Thm.eq_thm_prop (th1, th2) ;
    1.34 +
    1.35  structure Data = Generic_Data
    1.36  (
    1.37    type T = {
    1.38 -    concrete_thms : thm Item_Net.T,
    1.39 -    generic_thms : thm Item_Net.T,
    1.40 +    measurable_thms : (thm * (bool * level)) Item_Net.T,
    1.41      dest_thms : thm Item_Net.T,
    1.42      app_thms : thm Item_Net.T }
    1.43    val empty = {
    1.44 -    concrete_thms = Thm.full_rules,
    1.45 -    generic_thms = Thm.full_rules,
    1.46 +    measurable_thms = Item_Net.init eq_measurable_thms (single o Thm.prop_of o fst),
    1.47      dest_thms = Thm.full_rules,
    1.48 -    app_thms = Thm.full_rules};
    1.49 +    app_thms = Thm.full_rules };
    1.50    val extend = I;
    1.51 -  fun merge ({concrete_thms = ct1, generic_thms = gt1, dest_thms = dt1, app_thms = at1 },
    1.52 -      {concrete_thms = ct2, generic_thms = gt2, dest_thms = dt2, app_thms = at2 }) = {
    1.53 -    concrete_thms = Item_Net.merge (ct1, ct2),
    1.54 -    generic_thms = Item_Net.merge (gt1, gt2),
    1.55 +  fun merge ({measurable_thms = t1, dest_thms = dt1, app_thms = at1 },
    1.56 +      {measurable_thms = t2, dest_thms = dt2, app_thms = at2 }) = {
    1.57 +    measurable_thms = Item_Net.merge (t1, t2),
    1.58      dest_thms = Item_Net.merge (dt1, dt2),
    1.59      app_thms = Item_Net.merge (at1, at2) };
    1.60  );
    1.61 @@ -53,38 +49,36 @@
    1.62  val debug =
    1.63    Attrib.setup_config_bool @{binding measurable_debug} (K false)
    1.64  
    1.65 -val backtrack =
    1.66 -  Attrib.setup_config_int @{binding measurable_backtrack} (K 20)
    1.67 -
    1.68  val split =
    1.69    Attrib.setup_config_bool @{binding measurable_split} (K true)
    1.70  
    1.71 -fun TAKE n tac = Seq.take n o tac
    1.72 +fun map_data f1 f2 f3
    1.73 +  {measurable_thms = t1,    dest_thms = t2,    app_thms = t3} =
    1.74 +  {measurable_thms = f1 t1, dest_thms = f2 t2, app_thms = f3 t3 }
    1.75  
    1.76 -fun get lv =
    1.77 -  rev o Item_Net.content o (case lv of Concrete => #concrete_thms | Generic => #generic_thms) o
    1.78 -  Data.get o Context.Proof;
    1.79 -
    1.80 -fun get_all ctxt = get Concrete ctxt @ get Generic ctxt;
    1.81 +fun map_measurable_thms f = map_data f I I
    1.82 +fun map_dest_thms f = map_data I f I
    1.83 +fun map_app_thms f = map_data I I f
    1.84  
    1.85 -fun map_data f1 f2 f3 f4
    1.86 -  {generic_thms = t1,    concrete_thms = t2,    dest_thms = t3,    app_thms = t4} =
    1.87 -  {generic_thms = f1 t1, concrete_thms = f2 t2, dest_thms = f3 t3, app_thms = f4 t4 }
    1.88 +fun generic_add_del map = 
    1.89 +  Scan.lift
    1.90 +    (Args.add >> K Item_Net.update || Args.del >> K Item_Net.remove || Scan.succeed Item_Net.update) >>
    1.91 +    (fn f => Thm.declaration_attribute (Data.map o map o f))
    1.92 +
    1.93 +val app_thm_attr = generic_add_del map_app_thms
    1.94  
    1.95 -fun map_concrete_thms f = map_data f I I I
    1.96 -fun map_generic_thms f = map_data I f I I
    1.97 -fun map_dest_thms f = map_data I I f I
    1.98 -fun map_app_thms f = map_data I I I f
    1.99 +val dest_thm_attr = generic_add_del map_dest_thms
   1.100  
   1.101 -fun update f lv = Data.map (case lv of Concrete => map_concrete_thms f | Generic => map_generic_thms f);
   1.102 -fun add thms' = update (fold Item_Net.update thms');
   1.103 -fun del thms' = update (fold Item_Net.remove thms');
   1.104 +fun del_thm th net =
   1.105 +  let
   1.106 +    val thms = net |> Item_Net.content |> filter (fn (th', _) => Thm.eq_thm (th, th'))
   1.107 +  in fold Item_Net.remove thms net end ;
   1.108 +
   1.109 +fun measurable_thm_attr (do_add, d) = Thm.declaration_attribute
   1.110 +  (Data.map o map_measurable_thms o (if do_add then Item_Net.update o rpair d else del_thm))
   1.111  
   1.112  val get_dest = Item_Net.content o #dest_thms o Data.get;
   1.113 -val add_dest = Data.map o map_dest_thms o Item_Net.update;
   1.114 -
   1.115  val get_app = Item_Net.content o #app_thms o Data.get;
   1.116 -val add_app = Data.map o map_app_thms o Item_Net.update;
   1.117  
   1.118  fun is_too_generic thm =
   1.119    let 
   1.120 @@ -95,12 +89,18 @@
   1.121  fun import_theorem ctxt thm = if is_too_generic thm then [] else
   1.122    [thm] @ map_filter (try (fn th' => thm RS th')) (get_dest ctxt);
   1.123  
   1.124 -fun add_del_thm_gen f (raw, lv) thm ctxt = f (if raw then [thm] else import_theorem ctxt thm) lv ctxt;
   1.125 +val get = Context.Proof #> Data.get #> #measurable_thms #> Item_Net.content ;
   1.126 +
   1.127 +val get_all = get #> map fst ;
   1.128  
   1.129 -val add_thm = add_del_thm_gen add;
   1.130 -val del_thm = add_del_thm_gen del;
   1.131 -fun add_del_thm true = add_thm
   1.132 -  | add_del_thm false = del_thm
   1.133 +fun get_thms ctxt =
   1.134 +  let
   1.135 +    val thms = ctxt |> get |> rev ;
   1.136 +    fun get lv = map_filter (fn (th, (rw, lv')) => if lv = lv' then SOME (th, rw) else NONE) thms
   1.137 +  in
   1.138 +    get Concrete @ get Generic |>
   1.139 +    maps (fn (th, rw) => if rw then [th] else import_theorem (Context.Proof ctxt) th)
   1.140 +  end;
   1.141  
   1.142  fun debug_tac ctxt msg f = if Config.get ctxt debug then print_tac ctxt (msg ()) THEN f else f
   1.143  
   1.144 @@ -158,10 +158,46 @@
   1.145      in if null cps then no_tac else debug_tac ctxt (K "split countable fun") (resolve_tac cps i) end
   1.146      handle TERM _ => no_tac) 1)
   1.147  
   1.148 -fun measurable_tac' ctxt facts =
   1.149 +val split_app_tac =
   1.150 +  Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
   1.151 +    let
   1.152 +      fun app_prefixes (Abs (n, T, (f $ g))) = let
   1.153 +            val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
   1.154 +          in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
   1.155 +        | app_prefixes _ = []
   1.156 +
   1.157 +      fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
   1.158 +        | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
   1.159 +      val thy = Proof_Context.theory_of ctxt
   1.160 +      val tunify = Sign.typ_unify thy
   1.161 +      val thms = map
   1.162 +          (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
   1.163 +          (get_app (Context.Proof ctxt))
   1.164 +      fun cert f = map (fn (t, t') => (f thy t, f thy t'))
   1.165 +      fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
   1.166 +        let
   1.167 +          val inst =
   1.168 +            (Vartab.empty, ~1)
   1.169 +            |> tunify (T, thmT)
   1.170 +            |> tunify (Tf, thmTf)
   1.171 +            |> tunify (Tc, thmTc)
   1.172 +            |> Vartab.dest o fst
   1.173 +          val subst = subst_TVars (map (apsnd snd) inst)
   1.174 +        in
   1.175 +          Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
   1.176 +            cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
   1.177 +        end
   1.178 +      val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
   1.179 +    in if null cps then no_tac
   1.180 +        else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
   1.181 +          ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
   1.182 +    handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
   1.183 +    handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
   1.184 +
   1.185 +fun measurable_tac ctxt facts =
   1.186    let
   1.187      val imported_thms =
   1.188 -      (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf ctxt) facts) @ get_all ctxt
   1.189 +      (maps (import_theorem (Context.Proof ctxt) o Simplifier.norm_hhf ctxt) facts) @ get_thms ctxt
   1.190  
   1.191      fun debug_facts msg () =
   1.192        msg ^ " + " ^ Pretty.str_of (Pretty.list "[" "]"
   1.193 @@ -169,42 +205,6 @@
   1.194  
   1.195      val splitter = if Config.get ctxt split then split_countable_tac ctxt else K no_tac
   1.196  
   1.197 -    val split_app_tac =
   1.198 -      Subgoal.FOCUS (fn {context = ctxt, ...} => SUBGOAL (fn (t, i) =>
   1.199 -        let
   1.200 -          fun app_prefixes (Abs (n, T, (f $ g))) = let
   1.201 -                val ps = (if not (loose_bvar1 (g, 0)) then [(f, g)] else [])
   1.202 -              in map (fn (f, c) => (Abs (n, T, f), c, T, type_of c, type_of1 ([T], f $ c))) ps end
   1.203 -            | app_prefixes _ = []
   1.204 -
   1.205 -          fun dest_app (Abs (_, T, t as ((f $ Bound 0) $ c))) = (f, c, T, type_of c, type_of1 ([T], t))
   1.206 -            | dest_app t = raise (TERM ("not a measurability predicate of an application", [t]))
   1.207 -          val thy = Proof_Context.theory_of ctxt
   1.208 -          val tunify = Sign.typ_unify thy
   1.209 -          val thms = map
   1.210 -              (fn thm => (thm, dest_app (dest_measurable_fun (HOLogic.dest_Trueprop (concl_of thm)))))
   1.211 -              (get_app (Context.Proof ctxt))
   1.212 -          fun cert f = map (fn (t, t') => (f thy t, f thy t'))
   1.213 -          fun inst (f, c, T, Tc, Tf) (thm, (thmf, thmc, thmT, thmTc, thmTf)) =
   1.214 -            let
   1.215 -              val inst =
   1.216 -                (Vartab.empty, ~1)
   1.217 -                |> tunify (T, thmT)
   1.218 -                |> tunify (Tf, thmTf)
   1.219 -                |> tunify (Tc, thmTc)
   1.220 -                |> Vartab.dest o fst
   1.221 -              val subst = subst_TVars (map (apsnd snd) inst)
   1.222 -            in
   1.223 -              Thm.instantiate (cert ctyp_of (map (fn (n, (s, T)) => (TVar (n, s), T)) inst),
   1.224 -                cert cterm_of [(subst thmf, f), (subst thmc, c)]) thm
   1.225 -            end
   1.226 -          val cps = map_product inst (app_prefixes (dest_measurable_fun (HOLogic.dest_Trueprop t))) thms
   1.227 -        in if null cps then no_tac
   1.228 -            else debug_tac ctxt (K ("split app fun")) (resolve_tac cps i)
   1.229 -              ORELSE debug_tac ctxt (fn () => "FAILED") no_tac end
   1.230 -        handle TERM t => debug_tac ctxt (fn () => "TERM " ^ fst t ^ Pretty.str_of (Pretty.list "[" "]" (map (Syntax.pretty_term ctxt) (snd t)))) no_tac
   1.231 -        handle Type.TUNIFY => debug_tac ctxt (fn () => "TUNIFY") no_tac) 1)
   1.232 -
   1.233      fun REPEAT_cnt f n st = ((f n THEN REPEAT_cnt f (n + 1)) ORELSE all_tac) st
   1.234  
   1.235      val depth_measurable_tac = REPEAT_cnt (fn n =>
   1.236 @@ -216,14 +216,11 @@
   1.237  
   1.238    in debug_tac ctxt (debug_facts "start") depth_measurable_tac end;
   1.239  
   1.240 -fun measurable_tac ctxt facts =
   1.241 -  TAKE (Config.get ctxt backtrack) (measurable_tac' ctxt facts);
   1.242 -
   1.243  fun simproc ctxt redex =
   1.244    let
   1.245      val t = HOLogic.mk_Trueprop (term_of redex);
   1.246      fun tac {context = ctxt, prems = _ } =
   1.247 -      SOLVE (measurable_tac' ctxt (Simplifier.prems_of ctxt));
   1.248 +      SOLVE (measurable_tac ctxt (Simplifier.prems_of ctxt));
   1.249    in try (fn () => Goal.prove ctxt [] [] t tac RS @{thm Eq_TrueI}) () end;
   1.250  
   1.251  end