src/HOL/Statespace/state_fun.ML
author wenzelm
Mon Jun 23 23:45:39 2008 +0200 (2008-06-23)
changeset 27330 1af2598b5f7d
parent 27099 2a91d9575935
child 28308 d4396a28fb29
permissions -rw-r--r--
Logic.all/mk_equals/mk_implies;
     1 (*  Title:      state_fun.ML
     2     ID:         $Id$
     3     Author:     Norbert Schirmer, TU Muenchen
     4 *)
     5 
     6 signature STATE_FUN =
     7 sig
         (* Full names of the lookup / update constants
            ("StateFun.lookup" and "StateFun.update"). *)
     8   val lookupN : string
     9   val updateN : string
    10 
         (* Build the constructor (mk_constr) resp. destructor (mk_destr) term
            for values of the given type; the constant names are derived from
            the name of the type.  Raises TYPE for unsupported types
            (type variables, datatypes with more than one argument). *)
    11   val mk_constr : Context.theory -> Term.typ -> Term.term
    12   val mk_destr : Context.theory -> Term.typ -> Term.term
    13 
         (* lookup_simproc:       simplifies  lookup d n (update d' c m v s).
            update_simproc:       merges repeated updates of the same name
                                  within a chain of updates.
            ex_lookup_eq_simproc: rewrites  EX s. lookup d n s = x  (either
                                  orientation of the equation) to True.
            ex_lookup_ss:         initial simpset used by ex_lookup_eq_simproc;
                                  extended by the statefun_simp attribute.
            lazy_conj_simproc:    simplifies  P & Q, rewriting Q only when P
                                  does not simplify to False.
            string_eq_simp_tac:   simp tactic (given subgoal number) with rules
                                  for equalities of string literals. *)
    14   val lookup_simproc : MetaSimplifier.simproc
    15   val update_simproc : MetaSimplifier.simproc
    16   val ex_lookup_eq_simproc : MetaSimplifier.simproc
    17   val ex_lookup_ss : MetaSimplifier.simpset
    18   val lazy_conj_simproc : MetaSimplifier.simproc
    19   val string_eq_simp_tac : int -> Tactical.tactic
    20 
         (* Theory setup: initialises the StateFun data and declares the
            statefun_simp attribute. *)
    21   val setup : Context.theory -> Context.theory
    22 end;
    23 
    24 structure StateFun: STATE_FUN = 
    25 struct
    26 
    (* Fully qualified names of the lookup and update constants. *)
    val (lookupN, updateN) = ("StateFun.lookup", "StateFun.update");

    (* Decode a HOL string literal (a field name) into an ML string.
       Fails on other terms; mk_name relies on that via try. *)
    fun sel_name t = HOLogic.dest_string t;
    31 
    (* Choose a variable name for the i-th update value (used by
       update_simproc): the decoded string literal if t is one, otherwise
       the name of a Free or Const, and "x" followed by i as a last resort. *)
    fun mk_name i t =
      (case (try sel_name t, t) of
         (SOME name, _) => name
       | (NONE, Free (x, _)) => x
       | (NONE, Const (x, _)) => x
       | (NONE, _) => "x" ^ string_of_int i);
    39                
    40 local
         (* Helper theorems from the StateFun theory (their statements are not
            shown here) that justify the results of lazy_conj_simproc. *)
    41 val conj1_False = thm "conj1_False";
    42 val conj2_False = thm "conj2_False";
    43 val conj_True = thm "conj_True";
    44 val conj_cong = thm "conj_cong";
         (* Purely syntactic tests for the HOL constants False and True. *)
    45 fun isFalse (Const ("False",_)) = true
    46   | isFalse _ = false;
    47 fun isTrue (Const ("True",_)) = true
    48   | isTrue _ = false;
    49 
    50 in
         (* Simproc for  P & Q  that rewrites P first and rewrites Q only if P
            did not become False, so Q is never touched when P is False.
            Judging by the lemma names: the result is False if either side
            becomes False, True if both become True, otherwise a congruence
            result via conj_cong; NONE when neither side changed. *)
    51 val lazy_conj_simproc =
    52   Simplifier.simproc HOL.thy "lazy_conj_simp" ["P & Q"]
    53     (fn thy => fn ss => fn t =>
    54       (case t of (Const ("op &",_)$P$Q) => 
    55          let
    56             val P_P' = Simplifier.rewrite ss (cterm_of thy P);
                    (* P' is the right-hand side of the meta equation P_P' *)
    57             val P' = P_P' |> prop_of |> Logic.dest_equals |> #2 
    58          in if isFalse P'
    59             then SOME (conj1_False OF [P_P'])
    60             else 
    61               let
                        (* Q is only rewritten here, i.e. after P turned out not to be False *)
    62                 val Q_Q' = Simplifier.rewrite ss (cterm_of thy Q);
    63                 val Q' = Q_Q' |> prop_of |> Logic.dest_equals |> #2 
    64               in if isFalse Q'
    65                  then SOME (conj2_False OF [Q_Q'])
    66                  else if isTrue P' andalso isTrue Q'
    67                       then SOME (conj_True OF [P_P', Q_Q'])
                              (* unchanged conjunction: report that the simproc did not apply *)
    68                       else if P aconv P' andalso Q aconv Q' then NONE
    69                            else SOME (conj_cong OF [P_P', Q_Q'])
    70               end 
    71          end
    72         
    73       | _ => NONE));
    74 
         (* Simplification tactic (argument: subgoal number) with a basic simpset
            containing injectivity of list and char, simp_thms, datatype
            distinctness and lazy_conj_simproc; presumably meant for deciding
            equalities between string literals such as field names. *)
    75 val string_eq_simp_tac =
    76      simp_tac (HOL_basic_ss 
    77                  addsimps (thms "list.inject"@thms "char.inject"@simp_thms)
    78                  addsimprocs [DatatypePackage.distinct_simproc,lazy_conj_simproc]
    79                  addcongs [thm "block_conj_cong"])
    80 end;
    81 
    82 
    83 
    (* Default simpset used by lookup_simproc: string equality reasoning
       (list/char injectivity, datatype distinctness, lazy conjunctions)
       plus the lookup/update rules of the StateFun theory; distinctness of
       names is discharged by StateSpace.distinctNameSolver. *)
    val lookup_ss =
      let
        val rules =
          [thm "StateFun.lookup_update_id_same",
           thm "StateFun.id_id_cancel",
           thm "StateFun.lookup_update_same",
           thm "StateFun.lookup_update_other"];
      in
        HOL_basic_ss
          addsimps (thms "list.inject" @ thms "char.inject" @ simp_thms @ rules)
          addsimprocs [DatatypePackage.distinct_simproc, lazy_conj_simproc]
          addcongs [thm "block_conj_cong"]
          addSolver StateSpace.distinctNameSolver
      end;

    (* Default simpset used by ex_lookup_eq_simproc. *)
    val ex_lookup_ss = HOL_ss addsimps [thm "StateFun.ex_id"];
    99 
   (* Argument structure for the generic context data of this module. *)
   structure StateFunArgs =
   struct
     (* (simpset for lookup_simproc,
         simpset for ex_lookup_eq_simproc,
         flag: have lookup_simproc/update_simproc been added already?) *)
     type T = simpset * simpset * bool;

     val empty : T = (empty_ss, empty_ss, false);

     fun extend (data : T) = data;

     (* Pointwise merge: both simpsets are merged; the simprocs count as
        installed if they were installed on either side. *)
     fun merge _ ((lookup1, ex1, active1), (lookup2, ex2, active2)) : T =
       let
         val lookup' = merge_ss (lookup1, lookup2);
         val ex' = merge_ss (ex1, ex2);
       in (lookup', ex', active1 orelse active2) end;
   end;
   111 
   112 
   structure StateFunData = GenericDataFun(StateFunArgs);

   (* Store the default simpsets in the theory data; the simprocs are not
      yet installed, hence the flag is false. *)
   fun init_state_fun_data thy =
     Context.theory_map (StateFunData.put (lookup_ss, ex_lookup_ss, false)) thy;
   117 
   118 val lookup_simproc =
         (* Simproc for  lookup d n (update d' c m v s).
            The update chain is first abstracted by mk_upds: every value of the
            form  K_statefun v''  becomes  K_statefun ?v_k  with a fresh
            schematic variable, except  K_statefun (lookup id n' _)  in an
            update with constructor id of the looked-up name n (kept so that
            lookup_update_id_same can fire); other values are left unchanged
            and the innermost non-update state becomes ?s.
            The abstracted lookup is rewritten with the lookup simpset stored
            in StateFunData (simp depth limit 100); the resulting equation is
            returned unless both sides are alpha-equivalent. *)
   119   Simplifier.simproc (the_context ()) "lookup_simp" ["lookup d n (update d' c m v s)"]
   120     (fn thy => fn ss => fn t =>
   121       (case t of (Const ("StateFun.lookup",lT)$destr$n$
   122                    (s as Const ("StateFun.update",uT)$_$_$_$_$_)) =>
   123         (let
                    (* sT: type of the state, the fifth argument of update *)
   124           val (_::_::_::_::sT::_) = binder_types uT;
                    (* fresh schematic indices start above the maximal index of t *)
   125           val mi = maxidx_of_term t;
                    (* Abstract an update chain; returns the skeleton and the next
                       free index.  Inside, n is the name of the outer lookup. *)
   126           fun mk_upds (Const ("StateFun.update",uT)$d'$c$m$v$s) =
   127                let
   128                  val (_::_::_::fT::_::_) = binder_types uT;
   129                  val vT = domain_type fT;
   130                  val (s',cnt) = mk_upds s;
   131                  val (v',cnt') = 
   132                       (case v of
   133                         Const ("StateFun.K_statefun",KT)$v''
   134                          => (case v'' of 
   135                              (Const ("StateFun.lookup",_)$(d as (Const ("Fun.id",_)))$n'$_)
   136                               => if d aconv c andalso n aconv m andalso m aconv n' 
   137                                  then (v,cnt) (* Keep value so that 
   138                                                  lookup_update_id_same can fire *)
   139                                  else (Const ("StateFun.K_statefun",KT)$Var (("v",cnt),vT),
   140                                        cnt+1)
   141                               | _ => (Const ("StateFun.K_statefun",KT)$Var (("v",cnt),vT),
   142                                        cnt+1))
   143                        | _ => (v,cnt));
   144                in (Const ("StateFun.update",uT)$d'$c$m$v'$s',cnt')
   145                end
                      (* innermost state: ?s at index mi+1, values continue from mi+2 *)
   146             | mk_upds s = (Var (("s",mi+1),sT),mi+2);
   147           
   148           val ct = cterm_of thy 
   149                     (Const ("StateFun.lookup",lT)$destr$n$(fst (mk_upds s)));
                    (* raises Option when the simpset carries no context; handled below *)
   150           val ctxt = the (#context (#1 (rep_ss ss)));
   151           val basic_ss = #1 (StateFunData.get (Context.Proof ctxt));
   152           val ss' = Simplifier.context 
   153                      (Config.map MetaSimplifier.simp_depth_limit (K 100) ctxt) basic_ss;
   154           val thm = Simplifier.rewrite ss' ct;
                  (* a trivial equation means the simproc made no progress *)
   155         in if (op aconv) (Logic.dest_equals (prop_of thm))
   156            then NONE
   157            else SOME thm
   158         end
   159         handle Option => NONE)
   160          
   161       | _ => NONE ));
   162 
   163 
   (* Fold over a non-empty list, seeded with its first element:
        foldl1 f [x1, x2, ..., xn] = foldl f x1 [x2, ..., xn]
      f is called as foldl calls it (the Basis List.foldl passes
      (element, accumulator)).  Raises Empty on the empty list instead of
      failing with an uninformative Match. *)
   fun foldl1 f (x :: xs) = foldl f x xs
     | foldl1 _ [] = raise Empty;
   165 
   166 local
         (* Theorems and the simpset ss' used below to prove the equation
            between an update-chain skeleton and its merged form. *)
   167 val update_apply = thm "StateFun.update_apply";
   168 val meta_ext = thm "StateFun.meta_ext";
   169 val o_apply = thm "Fun.o_apply";
   170 val ss' = (HOL_ss addsimps (update_apply::o_apply::thms "list.inject"@thms "char.inject")
   171                  addsimprocs [DatatypePackage.distinct_simproc,lazy_conj_simproc,StateSpace.distinct_simproc]
   172                  addcongs [thm "block_conj_cong"]);
   173 in
         (* Simproc for chains of  update d c n v s: all updates of the same
            name are merged into a single update at the position of the
            outermost one, whose value is a composition of the individual value
            functions (with the intermediate destructors/constructors).
            Returns SOME only if at least one update was merged; the merged
            term is then further rewritten with the lookup simpset stored in
            StateFunData. *)
   174 val update_simproc =
   175   Simplifier.simproc (the_context ()) "update_simp" ["update d c n v s"]
   176     (fn thy => fn ss => fn t =>
   177       (case t of ((upd as Const ("StateFun.update", uT)) $ d $ c $ n $ v $ s) =>
   178          let 
   179             
   180              val (_::_::_::_::sT::_) = binder_types uT;
   181                 (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) => ('n => 'v) => ('n => 'v)"*)
                      (* Seed skeleton for the innermost state: Bound 0 on both sides,
                         the single variable s, no components, nothing merged yet. *)
   182              fun init_seed s = (Bound 0,Bound 0, [("s",sT)],[], false);
   183 
                      (* The term  Fun.comp $ g $ f  together with its type. *)
   184              fun mk_comp f fT g gT =
   185                let val T = (domain_type fT --> range_type gT) 
   186                in (Const ("Fun.comp",gT --> fT --> T)$g$f,T) end
   187 
                      (* Compose a non-empty list of (term, type) pairs with mk_comp. *)
   188              fun mk_comps fs = 
   189                    foldl1 (fn ((f,fT),(g,gT)) => mk_comp f fT g gT) fs;
   190 
                      (* Prepend the components of one update of n (constructor c,
                         value f, destructor d, each with its type) to the entry of n. *)
   191              fun append n c cT f fT d dT comps =
   192                (case AList.lookup (op aconv) comps n of
   193                   SOME gTs => AList.update (op aconv) 
   194                                     (n,[(c,cT),(f,fT),(d,dT)]@gTs) comps
   195                 | NONE => AList.update (op aconv) (n,[(c,cT),(f,fT),(d,dT)]) comps)
   196 
                      (* Split a non-empty list into first element, middle part and last element. *)
   197              fun split_list (x::xs) = let val (xs',y) = split_last xs in (x,xs',y) end
   198                | split_list _ = error "StateFun.split_list";
   199 
                      (* Merged components of n: first constructor, composition of all
                         middle components, last destructor. *)
   200              fun merge_upds n comps =
   201                let val ((c,cT),fs,(d,dT)) = split_list (the (AList.lookup (op aconv) comps n))
   202                in ((c,cT),fst (mk_comps fs),(d,dT)) end;
   203 
   204              (* mk_updterm returns the 5-tuple
   205               *  - (orig-term-skeleton, simplified-term-skeleton, vars, comps, b)
   206               *     where vars lists the bound variables of both skeletons (the
                     *     variable of the outermost update first, the seed state s last),
                     *     comps maps each update name to its collected (term, type)
                     *     pairs: constructor, value variable and destructor of every
                     *     update of that name, outermost update first,
                     *     and the boolean b tells if a simplification has occurred.
   207               *     "orig-term-skeleton = simplified-term-skeleton" is
   208               *     the desired simplification rule.
   209               * The algorithm first walks down the updates to the seed-state while
   210               * memorising the updates in the already-table. While walking up the
   211               * updates again, the optimised term is constructed.
   212               *)
   213              fun mk_updterm already
   214                  (t as ((upd as Const ("StateFun.update", uT)) $ d $ c $ n $ v $ s)) =
   215                       let
   216                          fun rest already = mk_updterm already;
   217                          val (dT::cT::nT::vT::sT::_) = binder_types uT;
   218                           (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) => 
   219                                 ('n => 'v) => ('n => 'v)"*)
   220                       in if member (op aconv) already n
                                (* n is updated again further out: this update is dropped
                                   from the simplified skeleton, only its components are
                                   recorded, and b becomes true. *)
   221                          then (case rest already s of
   222                                  (trm,trm',vars,comps,_) =>
   223                                    let
   224                                      val i = length vars;
   225                                      val kv = (mk_name i n,vT);
   226                                      val kb = Bound i;
   227                                      val comps' = append n c cT kb vT d dT comps;
   228                                    in (upd$d$c$n$kb$trm, trm', kv::vars,comps',true) end)
   229                          else
                                (* outermost update of n: rebuild it in the simplified
                                   skeleton from the merged components of all updates of n. *)
   230                           (case rest (n::already) s of
   231                              (trm,trm',vars,comps,b) =>
   232                                 let
   233                                    val i = length vars;
   234                                    val kv = (mk_name i n,vT);
   235                                    val kb = Bound i;
   236                                    val comps' = append n c cT kb vT d dT comps;
   237                                    val ((c',c'T),f',(d',d'T)) = merge_upds n comps';
   238                                    val vT' = range_type d'T --> domain_type c'T;
   239                                    val upd' = Const ("StateFun.update",d'T --> c'T --> nT --> vT' --> sT --> sT);
   240                                 in (upd$d$c$n$kb$trm, upd'$d'$c'$n$f'$trm', kv::vars,comps',b) 
   241                                 end)
   242                       end
   243                | mk_updterm _ t = init_seed t;
   244 
                      (* Context with simp depth limit 100; ss1 proves the skeleton
                         equation, ss2 (the lookup simpset from StateFunData) rewrites
                         its right-hand side afterwards. *)
   245 	     val ctxt = the (#context (#1 (rep_ss ss))) |>
   246                         Config.map MetaSimplifier.simp_depth_limit (K 100);
   247              val ss1 = Simplifier.context ctxt ss';
   248              val ss2 = Simplifier.context ctxt 
   249                          (#1 (StateFunData.get (Context.Proof ctxt)));
   250          in (case mk_updterm [] t of
   251                (trm,trm',vars,_,true)
   252                 => let
                             (* eq1: the closed skeleton equation over all vars, proved by
                                meta_ext followed by simp with ss1 *)
   253                      val eq1 = Goal.prove ctxt [] [] 
   254                                       (list_all (vars, Logic.mk_equals (trm, trm')))
   255                                       (fn _ => rtac meta_ext 1 THEN 
   256                                                simp_tac ss1 1);
   257                      val eq2 = Simplifier.asm_full_rewrite ss2 (Thm.dest_equals_rhs (cprop_of eq1));
   258                    in SOME (transitive eq1 eq2) end
   259              | _ => NONE)
   260          end
   261        | _ => NONE));
   262 end
   263 
   264 
   265 
   266 
   (* Simproc ex_lookup_eq_simproc: rewrites
        Ex (%s. lookup d n s' = x)   or   Ex (%s. x = lookup d n s')
      to True, where s' is either the bound state s itself or a record
      selector applied to s, and neither n nor x mentions s.
      The generic lemma  !!n x. (EX s. lookup d n s' = x) == True  is proved
      by record_split_simp_tac followed by simp with the ex_lookup simpset
      stored in StateFunData, and generalised with standard; for the swapped
      orientation the theorem is post-processed with swap_ex_eq.
      Returns NONE whenever the term does not have this shape. *)
   local
   val swap_ex_eq = thm "StateFun.swap_ex_eq";

   (* Does sel name a field (or the more part) of the record type T? *)
   fun is_selector thy T sel =
     let
       val (flds, more) = RecordPackage.get_recT_fields thy T
     in member (fn (s, (n, _)) => n = s) (more :: flds) sel
     end;
   in
   val ex_lookup_eq_simproc =
     Simplifier.simproc HOL.thy "ex_lookup_eq_simproc" ["Ex t"]
       (fn thy => fn ss => fn t =>
          let
            val ctxt = Simplifier.the_context ss |>
                       Config.map MetaSimplifier.simp_depth_limit (K 100)
            val ex_lookup_ss = #2 (StateFunData.get (Context.Proof ctxt));
            val ss' = Simplifier.context ctxt ex_lookup_ss;
            fun prove prop =
              Goal.prove_global thy [] [] prop
                (fn _ => record_split_simp_tac [] (K ~1) 1 THEN
                         simp_tac ss' 1);

            (* Build the generic equation  lookup d (Bound 2) s = Bound 1, to be
               placed under  Ex (%s. _)  and the binders n (Bound 2) and
               x (Bound 1).  Raises TERM if x or n depends on the bound state. *)
            fun mkeq (swap, Teq, lT, lo, d, n, x, s) =
              let
                val (_ :: nT :: _) = binder_types lT;
                  (*  ('v => 'a) => 'n => ('n => 'v) => 'a *)
                val x' = if not (loose_bvar1 (x, 0))
                         then Bound 1
                         else raise TERM ("", [x]);
                val n' = if not (loose_bvar1 (n, 0))
                         then Bound 2
                         else raise TERM ("", [n]);
                val sel' = lo $ d $ n' $ s;
              in (Const ("op =", Teq) $ sel' $ x', hd (binder_types Teq), nT, swap) end;

            (* The state argument must be the bound state or a record selector
               applied to it. *)
            fun dest_state (s as Bound 0) = s
              | dest_state (s as (Const (sel, sT) $ Bound 0)) =
                  if is_selector thy (domain_type sT) sel
                  then s
                  else raise TERM ("StateFun.ex_lookup_eq_simproc: not a record selector", [s])
              | dest_state s =
                  raise TERM ("StateFun.ex_lookup_eq_simproc: not a record selector", [s]);

            (* Decompose  lookup d n s = X  (swap = false) or
               X = lookup d n s  (swap = true). *)
            fun dest_sel_eq (Const ("op =", Teq) $
                              ((lo as (Const ("StateFun.lookup", lT))) $ d $ n $ s) $ X) =
                  (false, Teq, lT, lo, d, n, X, dest_state s)
              | dest_sel_eq (Const ("op =", Teq) $ X $
                              ((lo as (Const ("StateFun.lookup", lT))) $ d $ n $ s)) =
                  (true, Teq, lT, lo, d, n, X, dest_state s)
              | dest_sel_eq _ = raise TERM ("", []);
          in
            (case t of
              (Const ("Ex", Tex) $ Abs (s, T, t)) =>
                (let
                   val (eq, eT, nT, swap) = mkeq (dest_sel_eq t);
                   val prop = list_all ([("n", nT), ("x", eT)],
                                Logic.mk_equals (Const ("Ex", Tex) $ Abs (s, T, eq),
                                                 HOLogic.true_const));
                   val thm = standard (prove prop);
                   val thm' = if swap then swap_ex_eq OF [thm] else thm
                 in SOME thm' end
                 handle TERM _ => NONE)
            | _ => NONE)
          end handle Option => NONE)
   end;
   330 
   (* Names of the value constructor/destructor constants: the qualifier
      val_prfx, then base_prfx, then s decorated with val_sfx via suffix. *)
   val val_sfx = "V";
   val val_prfx = "StateFun.";

   fun deco base_prfx s =
     let val decorated = suffix val_sfx s
     in String.concat [val_prfx, base_prfx, decorated] end;
   334 
   (* Capitalise the first character of a string; the empty string is
      returned unchanged. *)
   fun mkUpper "" = ""
     | mkUpper str =
         String.str (Char.toUpper (String.sub (str, 0))) ^ String.extract (str, 1, NONE);
   339 
   (* Name fragment for a type: the capitalised base names of its argument
      types followed by that of the type itself (presumably e.g. "NatList"
      for  nat list, assuming NameSpace.base strips the qualifier). *)
   fun mkName T =
     let
       val capBase = mkUpper o NameSpace.base
     in
       (case T of
         Type (tyco, args) => concat (map mkName args) ^ capBase tyco
       | TFree (x, _) => capBase x
       | TVar ((x, _), _) => capBase x)
     end;
   343 
   (* Is n the name of a type registered with the datatype package? *)
   fun is_datatype thy n =
     (case Symtab.lookup (DatatypePackage.get_datatypes thy) n of
       SOME _ => true
     | NONE => false);
   345 
   (* Map function used to lift a constructor/destructor through a datatype
      argument: List.map for lists, StateFun.map_ followed by the base name
      for every other datatype. *)
   fun mk_map n =
     if n = "List.list"
     then Syntax.const "List.map"
     else Syntax.const ("StateFun." ^  "map_" ^ NameSpace.base n);
   348 
   (* Build the constructor/destructor term for values of a type.
        comp -- combines two terms into a composition (see mk_constr/mk_destr)
        prfx -- name prefix of the constants ("" or "the_")
        thy  -- theory, used to recognise datatypes
      Nullary types and non-datatype type applications map directly to a
      constant named via deco.  Function types and one-argument datatypes
      are handled recursively by composing with the matching map function.
      Raises TYPE for datatypes with more than one argument and for type
      variables. *)
   fun gen_constr_destr comp prfx thy (Type (T, [])) =
         Syntax.const (deco prfx (mkUpper (NameSpace.base T)))
     | gen_constr_destr comp prfx thy (T as Type ("fun", _)) =
         let val (argTs, rangeT) = strip_type T;
         in comp
              (Syntax.const (deco prfx (concat (map mkName argTs) ^ "Fun")))
              (* one map_fun application per argument type, wrapped around the
                 term for the range type *)
              (fold (fn x => fn y => x $ y)
                   (replicate (length argTs) (Syntax.const "StateFun.map_fun"))
                   (gen_constr_destr comp prfx thy rangeT))
         end
     | gen_constr_destr comp prfx thy (T' as Type (T, argTs)) =
         if is_datatype thy T
         then (* datatype args are recursively embedded into val *)
           (case argTs of
             [argT] => comp
                         (Syntax.const (deco prfx (mkUpper (NameSpace.base T))))
                         (mk_map T $ gen_constr_destr comp prfx thy argT)
           | _ => raise (TYPE ("StateFun.gen_constr_destr", [T'], [])))
         else (* type args are not recursively embedded into val *)
           Syntax.const (deco prfx (concat (map mkName argTs) ^ mkUpper (NameSpace.base T)))
     (* Type variables are not supported; comp, prfx and thy are irrelevant here. *)
     | gen_constr_destr _ _ _ T = raise (TYPE ("StateFun.gen_constr_destr", [T], []));
   370                    
   local
     (* The pre-term  Fun.comp $ f $ g. *)
     fun comp_term f g = Syntax.const "Fun.comp" $ f $ g;
   in
   (* Constructor term for values of type T (constants without prefix). *)
   fun mk_constr thy T = gen_constr_destr (fn a => fn b => comp_term a b) "" thy T;
   (* Destructor term for values of type T: "the_" constants, with the
      operands of every composition swapped. *)
   fun mk_destr thy T = gen_constr_destr (fn a => fn b => comp_term b a) "the_" thy T;
   end;
   373 
   374   
   (* Attribute statefun_simp: adds thm to the ex_lookup simpset when its
      conclusion has the shape  _ $ (Ex $ _)  (presumably Trueprop (EX ...)),
      otherwise to the lookup simpset.  On first use it also adds
      lookup_simproc and update_simproc to the simpset of the context and
      records that in the data flag. *)
   fun statefun_simp_attr _ (context, thm) =
     let
       val (lookup_ss, ex_lookup_ss, simprocs_active) = StateFunData.get context;
       val is_ex_rule =
         (case concl_of thm of
           _ $ (Const ("Ex", _) $ _) => true
         | _ => false);
       val data' =
         if is_ex_rule
         then (lookup_ss, ex_lookup_ss addsimps [thm], true)
         else (lookup_ss addsimps [thm], ex_lookup_ss, true);
       val with_simprocs =
         if simprocs_active then context
         else Simplifier.map_ss (fn ss => ss addsimprocs [lookup_simproc, update_simproc]) context;
     in (StateFunData.put data' with_simprocs, thm) end;
   391 
   (* Theory setup: initialise the StateFun data, then declare the
      statefun_simp attribute. *)
   fun setup thy =
     thy
     |> init_state_fun_data
     |> Attrib.add_attributes
          [("statefun_simp", statefun_simp_attr, "simplification in statespaces")];
   396 end