src/HOL/Statespace/state_fun.ML
author hoelzl
Tue Mar 26 12:20:58 2013 +0100 (2013-03-26)
changeset 51526 155263089e7b
parent 46218 ecf6375e2abb
child 51717 9e7d1c139569
permissions -rw-r--r--
move SEQ.thy and Lim.thy to Limits.thy
     1 (*  Title:      HOL/Statespace/state_fun.ML
     2     Author:     Norbert Schirmer, TU Muenchen
     3 *)
     4 
     5 signature STATE_FUN =
     6 sig
     7   val lookupN : string
     8   val updateN : string
     9 
    10   val mk_constr : theory -> typ -> term
    11   val mk_destr : theory -> typ -> term
    12 
    13   val lookup_simproc : simproc
    14   val update_simproc : simproc
    15   val ex_lookup_eq_simproc : simproc
    16   val ex_lookup_ss : simpset
    17   val lazy_conj_simproc : simproc
    18   val string_eq_simp_tac : int -> tactic
    19 
    20   val setup : theory -> theory
    21 end;
    22 
    23 structure StateFun: STATE_FUN =
    24 struct
    25 
    26 val lookupN = @{const_name StateFun.lookup};
    27 val updateN = @{const_name StateFun.update};
    28 
    29 val sel_name = HOLogic.dest_string;
    30 
    31 fun mk_name i t =
    32   (case try sel_name t of
    33     SOME name => name
    34   | NONE =>
    35       (case t of
    36         Free (x, _) => x
    37       | Const (x, _) => x
    38       | _ => "x" ^ string_of_int i));
    39 
    40 local
    41 
    42 val conj1_False = @{thm conj1_False};
    43 val conj2_False = @{thm conj2_False};
    44 val conj_True = @{thm conj_True};
    45 val conj_cong = @{thm conj_cong};
    46 
    47 fun isFalse (Const (@{const_name False}, _)) = true
    48   | isFalse _ = false;
    49 
    50 fun isTrue (Const (@{const_name True}, _)) = true
    51   | isTrue _ = false;
    52 
    53 in
    54 
    55 val lazy_conj_simproc =
    56   Simplifier.simproc_global @{theory HOL} "lazy_conj_simp" ["P & Q"]
    57     (fn thy => fn ss => fn t =>
    58       (case t of (Const (@{const_name HOL.conj},_) $ P $ Q) =>
    59         let
    60           val P_P' = Simplifier.rewrite ss (cterm_of thy P);
    61           val P' = P_P' |> prop_of |> Logic.dest_equals |> #2;
    62         in
    63           if isFalse P' then SOME (conj1_False OF [P_P'])
    64           else
    65             let
    66               val Q_Q' = Simplifier.rewrite ss (cterm_of thy Q);
    67               val Q' = Q_Q' |> prop_of |> Logic.dest_equals |> #2;
    68             in
    69               if isFalse Q' then SOME (conj2_False OF [Q_Q'])
    70               else if isTrue P' andalso isTrue Q' then SOME (conj_True OF [P_P', Q_Q'])
    71               else if P aconv P' andalso Q aconv Q' then NONE
    72               else SOME (conj_cong OF [P_P', Q_Q'])
    73             end
    74          end
    75       | _ => NONE));
    76 
    77 val string_eq_simp_tac = simp_tac (HOL_basic_ss
    78   addsimps (@{thms list.inject} @ @{thms char.inject}
    79     @ @{thms list.distinct} @ @{thms char.distinct} @ @{thms simp_thms})
    80   addsimprocs [lazy_conj_simproc]
    81   |> Simplifier.add_cong @{thm block_conj_cong});
    82 
    83 end;
    84 
    85 val lookup_ss = (HOL_basic_ss
    86   addsimps (@{thms list.inject} @ @{thms char.inject}
    87     @ @{thms list.distinct} @ @{thms char.distinct} @ @{thms simp_thms}
    88     @ [@{thm StateFun.lookup_update_id_same}, @{thm StateFun.id_id_cancel},
    89       @{thm StateFun.lookup_update_same}, @{thm StateFun.lookup_update_other}])
    90   addsimprocs [lazy_conj_simproc]
    91   addSolver StateSpace.distinctNameSolver
    92   |> fold Simplifier.add_cong @{thms block_conj_cong});
    93 
    94 val ex_lookup_ss = HOL_ss addsimps @{thms StateFun.ex_id};
    95 
    96 
    97 structure Data = Generic_Data
    98 (
    99   type T = simpset * simpset * bool;  (*lookup simpset, ex_lookup simpset, are simprocs installed*)
   100   val empty = (empty_ss, empty_ss, false);
   101   val extend = I;
   102   fun merge ((ss1, ex_ss1, b1), (ss2, ex_ss2, b2)) =
   103     (merge_ss (ss1, ss2), merge_ss (ex_ss1, ex_ss2), b1 orelse b2);
   104 );
   105 
   106 val init_state_fun_data =
   107   Context.theory_map (Data.put (lookup_ss, ex_lookup_ss, false));
   108 
    109 val lookup_simproc =
        (* Simproc for terms  lookup destr n (update d' c m v s).
           The update chain is first generalised by mk_upds, and the lookup of the
           generalised chain is rewritten with the lookup simpset stored in Data
           (component #1), run in the current context with simp_depth_limit 100.
           Returns NONE if that rewrite leaves the term unchanged, or if
           Option.Option is raised while computing it.
           NOTE(review): the returned equation is stated for the generalised term
           (with schematic ?v/?s variables); presumably the simplifier instantiates
           it by matching against t -- confirm. *)
    110   Simplifier.simproc_global @{theory} "lookup_simp" ["lookup d n (update d' c m v s)"]
    111     (fn thy => fn ss => fn t =>
    112       (case t of (Const (@{const_name StateFun.lookup}, lT) $ destr $ n $
    113                    (s as Const (@{const_name StateFun.update}, uT) $ _ $ _ $ _ $ _ $ _)) =>
    114         (let
                  (* sT: the state type, i.e. the fifth argument type of update *)
    115           val (_::_::_::_::sT::_) = binder_types uT;
                  (* fresh schematic indices are chosen above every index already in t *)
    116           val mi = maxidx_of_term t;
                  (* mk_upds rebuilds the update chain, returning the new term and the
                     next unused index:
                     - an update value K_statefun v'' becomes K_statefun ?v<cnt>, unless
                       v'' is  lookup id n' _  with d aconv c, n aconv m and m aconv n'
                       (that value is kept, see the inline comment below);
                     - any other update value is kept unchanged;
                     - the innermost non-update state becomes ?s with index mi + 1.
                     Inner updates are processed first, so indices start at mi + 2
                     for the innermost update and grow outwards. *)
    117           fun mk_upds (Const (@{const_name StateFun.update}, uT) $ d' $ c $ m $ v $ s) =
    118                 let
    119                   val (_ :: _ :: _ :: fT :: _ :: _) = binder_types uT;
    120                   val vT = domain_type fT;
    121                   val (s', cnt) = mk_upds s;
    122                   val (v', cnt') =
    123                     (case v of
    124                       Const (@{const_name K_statefun}, KT) $ v'' =>
    125                         (case v'' of
    126                           (Const (@{const_name StateFun.lookup}, _) $
    127                             (d as (Const (@{const_name Fun.id}, _))) $ n' $ _) =>
    128                               if d aconv c andalso n aconv m andalso m aconv n'
    129                               then (v,cnt) (* Keep value so that
    130                                               lookup_update_id_same can fire *)
    131                               else
    132                                 (Const (@{const_name StateFun.K_statefun}, KT) $
    133                                   Var (("v", cnt), vT), cnt + 1)
    134                         | _ =>
    135                           (Const (@{const_name StateFun.K_statefun}, KT) $
    136                             Var (("v", cnt), vT), cnt + 1))
    137                      | _ => (v, cnt));
    138                 in (Const (@{const_name StateFun.update}, uT) $ d' $ c $ m $ v' $ s', cnt') end
    139             | mk_upds s = (Var (("s", mi + 1), sT), mi + 2);
    140 
                  (* the original lookup, applied to the generalised update chain *)
    141           val ct =
    142             cterm_of thy (Const (@{const_name StateFun.lookup}, lT) $ destr $ n $ fst (mk_upds s));
    143           val ctxt = Simplifier.the_context ss;
                  (* the lookup simpset from Data, not the caller's simpset *)
    144           val basic_ss = #1 (Data.get (Context.Proof ctxt));
    145           val ss' = Simplifier.context (Config.put simp_depth_limit 100 ctxt) basic_ss;
    146           val thm = Simplifier.rewrite ss' ct;
    147         in
                  (* no progress: lhs and rhs of the rewrite are alpha-equivalent *)
    148           if (op aconv) (Logic.dest_equals (prop_of thm))
    149           then NONE
    150           else SOME thm
    151         end
    152         handle Option.Option => NONE)
    153       | _ => NONE ));
   154 
   155 
    156 local
        (* Private helpers for update_simproc: the meta_ext theorem and a simpset
           ss' (HOL_ss with update_apply, o_apply, list/char injectivity and
           distinctness, lazy_conj_simproc, StateSpace.distinct_simproc and the
           block_conj_cong congruences) used to prove the merged-update equation. *)
    157 
    158 val meta_ext = @{thm StateFun.meta_ext};
    159 val ss' = (HOL_ss addsimps
    160   (@{thm StateFun.update_apply} :: @{thm Fun.o_apply} :: @{thms list.inject} @ @{thms char.inject}
    161     @ @{thms list.distinct} @ @{thms char.distinct})
    162   addsimprocs [lazy_conj_simproc, StateSpace.distinct_simproc]
    163   |> fold Simplifier.add_cong @{thms block_conj_cong});
    164 
    165 in
    166 
    167 val update_simproc =
        (* Simproc for chains  update d c n v s.  When a name is updated more than
           once, the inner update of that name is dropped and its pieces
           (constructor, value, destructor) are composed with Fun.comp into the
           outer update of the same name.  The result is eq1 (proved on a skeleton
           in which every update value is abstracted by a bound variable, using
           meta_ext and simp_tac with ss1) chained with eq2, a rewrite of eq1's
           right-hand side by asm_full_rewrite with the lookup simpset from Data
           (component #1).  Returns NONE if no name is updated twice. *)
    168   Simplifier.simproc_global @{theory} "update_simp" ["update d c n v s"]
    169     (fn thy => fn ss => fn t =>
    170       (case t of
    171         ((upd as Const (@{const_name StateFun.update}, uT)) $ d $ c $ n $ v $ s) =>
    172           let
    173             val (_ :: _ :: _ :: _ :: sT :: _) = binder_types uT;
    174               (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) => ('n => 'v) => ('n => 'v)"*)
                    (* base case of mk_updterm: the seed state is Bound 0, recorded as
                       variable "s"; no collected pieces and no simplification yet *)
    175             fun init_seed s = (Bound 0, Bound 0, [("s", sT)], [], false);
    176 
                    (* mk_comp f fT g gT = (Fun.comp g f, type of the composition) *)
    177             fun mk_comp f fT g gT =
    178               let val T = domain_type fT --> range_type gT
    179               in (Const (@{const_name Fun.comp}, gT --> fT --> T) $ g $ f, T) end;
    180 
                    (* compose a non-empty list of (term, type) pieces with foldl1 *)
    181             fun mk_comps fs = foldl1 (fn ((f, fT), (g, gT)) => mk_comp g gT f fT) fs;
    182 
                    (* prepend the triple (c, f, d) to the pieces collected for name n *)
    183             fun append n c cT f fT d dT comps =
    184               (case AList.lookup (op aconv) comps n of
    185                 SOME gTs => AList.update (op aconv) (n, [(c, cT), (f, fT), (d, dT)] @ gTs) comps
    186               | NONE => AList.update (op aconv) (n, [(c, cT), (f, fT), (d, dT)]) comps);
    187 
                    (* split a list into (first, middle elements, last); error on [] *)
    188             fun split_list (x :: xs) = let val (xs', y) = split_last xs in (x, xs', y) end
    189               | split_list _ = error "StateFun.split_list";
    190 
                    (* for name n: the outermost constructor, the composition of all
                       middle pieces, and the innermost destructor *)
    191             fun merge_upds n comps =
    192               let val ((c, cT), fs, (d, dT)) = split_list (the (AList.lookup (op aconv) comps n))
    193               in ((c, cT), fst (mk_comps fs), (d, dT)) end;
    194 
    195                (* mk_updterm returns
    196                 *  - (orig-term-skeleton, simplified-term-skeleton, vars, comps, b)
    197                 *     where boolean b tells if a simplification has occurred and
                        *     comps maps each name to its collected (constructor, value,
                        *     destructor) pieces.
    198                 *     "orig-term-skeleton = simplified-term-skeleton" is
    199                 *     the desired simplification rule.
    200                 * The algorithm first walks down the updates to the seed-state while
    201                 * memorising the updates in the already-table. While walking up the
    202                 * updates again, the optimised term is constructed.
    203                 *)
    204             fun mk_updterm already
    205                 (t as ((upd as Const (@{const_name StateFun.update}, uT)) $ d $ c $ n $ v $ s)) =
    206                   let
    207                     fun rest already = mk_updterm already;
    208                     val (dT :: cT :: nT :: vT :: sT :: _) = binder_types uT;
    209                       (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) =>
    210                             ('n => 'v) => ('n => 'v)"*)
    211                   in
                            (* an outer update already wrote n: drop this update from the
                               simplified skeleton and record that a simplification happened *)
    212                     if member (op aconv) already n then
    213                       (case rest already s of
    214                         (trm, trm', vars, comps, _) =>
    215                           let
                                    (* the update value is abstracted as Bound i, named via mk_name *)
    216                             val i = length vars;
    217                             val kv = (mk_name i n, vT);
    218                             val kb = Bound i;
    219                             val comps' = append n c cT kb vT d dT comps;
    220                           in (upd $ d $ c $ n $ kb $ trm, trm', kv :: vars, comps',true) end)
                            (* first (outermost) update of n: emit one update whose pieces
                               merge everything collected for n below *)
    221                     else
    222                       (case rest (n :: already) s of
    223                         (trm, trm', vars, comps, b) =>
    224                           let
    225                             val i = length vars;
    226                             val kv = (mk_name i n, vT);
    227                             val kb = Bound i;
    228                             val comps' = append n c cT kb vT d dT comps;
    229                             val ((c', c'T), f', (d', d'T)) = merge_upds n comps';
    230                             val vT' = range_type d'T --> domain_type c'T;
    231                             val upd' =
    232                               Const (@{const_name StateFun.update},
    233                                 d'T --> c'T --> nT --> vT' --> sT --> sT);
    234                           in
    235                             (upd $ d $ c $ n $ kb $ trm, upd' $ d' $ c' $ n $ f' $ trm', kv :: vars, comps', b)
    236                           end)
    237                   end
    238               | mk_updterm _ t = init_seed t;
    239 
                    (* both proof steps run with simp_depth_limit 100 *)
    240             val ctxt = Simplifier.the_context ss |> Config.put simp_depth_limit 100;
    241             val ss1 = Simplifier.context ctxt ss';
    242             val ss2 = Simplifier.context ctxt (#1 (Data.get (Context.Proof ctxt)));
    243           in
    244             (case mk_updterm [] t of
    245               (trm, trm', vars, _, true) =>
    246                 let
                          (* vars are !!-quantified; the seed state is Bound 0 *)
    247                   val eq1 =
    248                     Goal.prove ctxt [] []
    249                       (Logic.list_all (vars, Logic.mk_equals (trm, trm')))
    250                       (fn _ => rtac meta_ext 1 THEN simp_tac ss1 1);
    251                   val eq2 = Simplifier.asm_full_rewrite ss2 (Thm.dest_equals_rhs (cprop_of eq1));
    252                 in SOME (Thm.transitive eq1 eq2) end
    253             | _ => NONE)
    254           end
    255       | _ => NONE));
    256 
    257 end;
   258 
   259 
   260 local
   261 
   262 val swap_ex_eq = @{thm StateFun.swap_ex_eq};
   263 
   264 fun is_selector thy T sel =
   265   let val (flds, more) = Record.get_recT_fields thy T
   266   in member (fn (s, (n, _)) => n = s) (more :: flds) sel end;
   267 
   268 in
   269 
   270 val ex_lookup_eq_simproc =
   271   Simplifier.simproc_global @{theory HOL} "ex_lookup_eq_simproc" ["Ex t"]
   272     (fn thy => fn ss => fn t =>
   273       let
   274         val ctxt = Simplifier.the_context ss |> Config.put simp_depth_limit 100;
   275         val ex_lookup_ss = #2 (Data.get (Context.Proof ctxt));
   276         val ss' = Simplifier.context ctxt ex_lookup_ss;
   277         fun prove prop =
   278           Goal.prove_global thy [] [] prop
   279             (fn _ => Record.split_simp_tac [] (K ~1) 1 THEN simp_tac ss' 1);
   280 
   281         fun mkeq (swap, Teq, lT, lo, d, n, x, s) i =
   282           let
   283             val (_ :: nT :: _) = binder_types lT;
   284             (*  ('v => 'a) => 'n => ('n => 'v) => 'a *)
   285             val x' = if not (Term.is_dependent x) then Bound 1 else raise TERM ("", [x]);
   286             val n' = if not (Term.is_dependent n) then Bound 2 else raise TERM ("", [n]);
   287             val sel' = lo $ d $ n' $ s;
   288           in (Const (@{const_name HOL.eq}, Teq) $ sel' $ x', hd (binder_types Teq), nT, swap) end;
   289 
   290         fun dest_state (s as Bound 0) = s
   291           | dest_state (s as (Const (sel, sT) $ Bound 0)) =
   292               if is_selector thy (domain_type sT) sel then s
   293               else raise TERM ("StateFun.ex_lookup_eq_simproc: not a record slector", [s])
   294           | dest_state s = raise TERM ("StateFun.ex_lookup_eq_simproc: not a record slector", [s]);
   295 
   296         fun dest_sel_eq
   297               (Const (@{const_name HOL.eq}, Teq) $
   298                 ((lo as (Const (@{const_name StateFun.lookup}, lT))) $ d $ n $ s) $ X) =
   299               (false, Teq, lT, lo, d, n, X, dest_state s)
   300           | dest_sel_eq
   301               (Const (@{const_name HOL.eq}, Teq) $ X $
   302                 ((lo as (Const (@{const_name StateFun.lookup}, lT))) $ d $ n $ s)) =
   303               (true, Teq, lT, lo, d, n, X, dest_state s)
   304           | dest_sel_eq _ = raise TERM ("", []);
   305       in
   306         (case t of
   307           Const (@{const_name Ex}, Tex) $ Abs (s, T, t) =>
   308             (let
   309               val (eq, eT, nT, swap) = mkeq (dest_sel_eq t) 0;
   310               val prop =
   311                 Logic.list_all ([("n", nT), ("x", eT)],
   312                   Logic.mk_equals (Const (@{const_name Ex}, Tex) $ Abs (s, T, eq), @{term True}));
   313               val thm = Drule.export_without_context (prove prop);
   314               val thm' = if swap then swap_ex_eq OF [thm] else thm
   315             in SOME thm' end handle TERM _ => NONE)
   316         | _ => NONE)
   317       end handle Option.Option => NONE);
   318 
   319 end;
   320 
   321 val val_sfx = "V";
   322 val val_prfx = "StateFun."
   323 fun deco base_prfx s = val_prfx ^ (base_prfx ^ suffix val_sfx s);
   324 
   325 fun mkUpper str =
   326   (case String.explode str of
   327     [] => ""
   328   | c::cs => String.implode (Char.toUpper c :: cs));
   329 
   330 fun mkName (Type (T,args)) = implode (map mkName args) ^ mkUpper (Long_Name.base_name T)
   331   | mkName (TFree (x,_)) = mkUpper (Long_Name.base_name x)
   332   | mkName (TVar ((x,_),_)) = mkUpper (Long_Name.base_name x);
   333 
   334 fun is_datatype thy = is_some o Datatype.get_info thy;
   335 
   336 fun mk_map "List.list" = Syntax.const "List.map"
   337   | mk_map n = Syntax.const ("StateFun.map_" ^ Long_Name.base_name n);
   338 
   339 fun gen_constr_destr comp prfx thy (Type (T, [])) =
   340       Syntax.const (deco prfx (mkUpper (Long_Name.base_name T)))
   341   | gen_constr_destr comp prfx thy (T as Type ("fun",_)) =
   342       let val (argTs, rangeT) = strip_type T;
   343       in
   344         comp
   345           (Syntax.const (deco prfx (implode (map mkName argTs) ^ "Fun")))
   346           (fold (fn x => fn y => x $ y)
   347             (replicate (length argTs) (Syntax.const "StateFun.map_fun"))
   348             (gen_constr_destr comp prfx thy rangeT))
   349       end
   350   | gen_constr_destr comp prfx thy (T' as Type (T, argTs)) =
   351       if is_datatype thy T
   352       then (* datatype args are recursively embedded into val *)
   353         (case argTs of
   354           [argT] =>
   355             comp
   356               ((Syntax.const (deco prfx (mkUpper (Long_Name.base_name T)))))
   357               ((mk_map T $ gen_constr_destr comp prfx thy argT))
   358         | _ => raise (TYPE ("StateFun.gen_constr_destr", [T'], [])))
   359       else (* type args are not recursively embedded into val *)
   360         Syntax.const (deco prfx (implode (map mkName argTs) ^ mkUpper (Long_Name.base_name T)))
   361   | gen_constr_destr thy _ _ T = raise (TYPE ("StateFun.gen_constr_destr", [T], []));
   362 
   363 val mk_constr = gen_constr_destr (fn a => fn b => Syntax.const @{const_name Fun.comp} $ a $ b) "";
   364 val mk_destr = gen_constr_destr (fn a => fn b => Syntax.const @{const_name Fun.comp} $ b $ a) "the_";
   365 
   366 
   367 val statefun_simp_attr = Thm.declaration_attribute (fn thm => fn ctxt =>
   368   let
   369     val (lookup_ss, ex_lookup_ss, simprocs_active) = Data.get ctxt;
   370     val (lookup_ss', ex_lookup_ss') =
   371       (case concl_of thm of
   372         (_ $ ((Const (@{const_name Ex}, _) $ _))) => (lookup_ss, ex_lookup_ss addsimps [thm])
   373       | _ => (lookup_ss addsimps [thm], ex_lookup_ss));
   374     fun activate_simprocs ctxt =
   375       if simprocs_active then ctxt
   376       else Simplifier.map_ss (fn ss => ss addsimprocs [lookup_simproc, update_simproc]) ctxt;
   377   in
   378     ctxt
   379     |> activate_simprocs
   380     |> Data.put (lookup_ss', ex_lookup_ss', true)
   381   end);
   382 
   383 val setup =
   384   init_state_fun_data #>
   385   Attrib.setup @{binding statefun_simp} (Scan.succeed statefun_simp_attr)
   386     "simplification in statespaces";
   387 
   388 end;