src/HOL/Statespace/state_fun.ML
author wenzelm
Sun Mar 08 17:26:14 2009 +0100 (2009-03-08)
changeset 30364 577edc39b501
parent 30289 b28caca9157f
child 30528 7173bf123335
permissions -rw-r--r--
moved basic algebra of long names from structure NameSpace to Long_Name;
     1 (*  Title:      HOL/Statespace/state_fun.ML
     2     Author:     Norbert Schirmer, TU Muenchen
     3 *)
     4 
     5 signature STATE_FUN =
     6 sig
     7   val lookupN : string
     8   val updateN : string
     9 
    10   val mk_constr : theory -> typ -> term
    11   val mk_destr : theory -> typ -> term
    12 
    13   val lookup_simproc : simproc
    14   val update_simproc : simproc
    15   val ex_lookup_eq_simproc : simproc
    16   val ex_lookup_ss : simpset
    17   val lazy_conj_simproc : simproc
    18   val string_eq_simp_tac : int -> tactic
    19 
    20   val setup : theory -> theory
    21 end;
    22 
    23 structure StateFun: STATE_FUN = 
    24 struct
    25 
    26 val lookupN = "StateFun.lookup";
    27 val updateN = "StateFun.update";
    28 
    29 val sel_name = HOLogic.dest_string;
    30 
    31 fun mk_name i t =
    32   (case try sel_name t of
    33      SOME name => name
    34    | NONE => (case t of 
    35                Free (x,_) => x
    36               |Const (x,_) => x
    37               |_ => "x"^string_of_int i))
    38                
    39 local
    40 val conj1_False = thm "conj1_False";
    41 val conj2_False = thm "conj2_False";
    42 val conj_True = thm "conj_True";
    43 val conj_cong = thm "conj_cong";
    44 fun isFalse (Const ("False",_)) = true
    45   | isFalse _ = false;
    46 fun isTrue (Const ("True",_)) = true
    47   | isTrue _ = false;
    48 
    49 in
    50 val lazy_conj_simproc =
    51   Simplifier.simproc @{theory HOL} "lazy_conj_simp" ["P & Q"]
    52     (fn thy => fn ss => fn t =>
    53       (case t of (Const ("op &",_)$P$Q) => 
    54          let
    55             val P_P' = Simplifier.rewrite ss (cterm_of thy P);
    56             val P' = P_P' |> prop_of |> Logic.dest_equals |> #2 
    57          in if isFalse P'
    58             then SOME (conj1_False OF [P_P'])
    59             else 
    60               let
    61                 val Q_Q' = Simplifier.rewrite ss (cterm_of thy Q);
    62                 val Q' = Q_Q' |> prop_of |> Logic.dest_equals |> #2 
    63               in if isFalse Q'
    64                  then SOME (conj2_False OF [Q_Q'])
    65                  else if isTrue P' andalso isTrue Q'
    66                       then SOME (conj_True OF [P_P', Q_Q'])
    67                       else if P aconv P' andalso Q aconv Q' then NONE
    68                            else SOME (conj_cong OF [P_P', Q_Q'])
    69               end 
    70          end
    71         
    72       | _ => NONE));
    73 
    74 val string_eq_simp_tac =
    75      simp_tac (HOL_basic_ss 
    76                  addsimps (thms "list.inject"@thms "char.inject"@simp_thms)
    77                  addsimprocs [DatatypePackage.distinct_simproc,lazy_conj_simproc]
    78                  addcongs [thm "block_conj_cong"])
    79 end;
    80 
    81 
    82 
    83 local
    84 val rules = 
    85  [thm "StateFun.lookup_update_id_same",
    86   thm "StateFun.id_id_cancel",
    87   thm "StateFun.lookup_update_same",thm "StateFun.lookup_update_other"
    88   ]
    89 in
    90 val lookup_ss = (HOL_basic_ss 
    91                  addsimps (thms "list.inject"@thms "char.inject"@simp_thms@rules)
    92                  addsimprocs [DatatypePackage.distinct_simproc,lazy_conj_simproc]
    93                  addcongs [thm "block_conj_cong"]
    94                  addSolver StateSpace.distinctNameSolver) 
    95 end;
    96 
    97 val ex_lookup_ss = HOL_ss addsimps [thm "StateFun.ex_id"];
    98 
    99 structure StateFunArgs =
   100 struct
   101   type T = (simpset * simpset * bool); 
   102            (* lookup simpset, ex_lookup simpset, are simprocs installed *)
   103   val empty = (empty_ss, empty_ss, false);
   104   val extend = I;
   105   fun merge pp ((ss1,ex_ss1,b1),(ss2,ex_ss2,b2)) = 
   106                (merge_ss (ss1,ss2)
   107                ,merge_ss (ex_ss1,ex_ss2)
   108                ,b1 orelse b2);
   109 end;
   110 
   111 
   112 structure StateFunData = GenericDataFun(StateFunArgs);
   113 
   114 val init_state_fun_data =
   115   Context.theory_map (StateFunData.put (lookup_ss,ex_lookup_ss,false));
   116 
(* lookup_simproc: simplification procedure for terms
     lookup d n (update d' c m v s)
   The update chain is abstracted first: every value K_statefun v'' is replaced
   by K_statefun ?v with a fresh schematic ?v (unless v'' is lookup id n' _ with
   c aconv that id constant and n, m, n' all aconv, which is kept so that
   lookup_update_id_same can fire); other values are left as they are, and the
   innermost state becomes a fresh schematic ?s.  The abstracted lookup is
   rewritten with the lookup simpset stored in StateFunData (simp depth limit
   100).  The resulting equation is returned unless its two sides are aconv.
   An Option exception raised on the way makes the simproc return NONE. *)
val lookup_simproc =
  Simplifier.simproc (the_context ()) "lookup_simp" ["lookup d n (update d' c m v s)"]
    (fn thy => fn ss => fn t =>
      (case t of (Const ("StateFun.lookup",lT)$destr$n$
                   (s as Const ("StateFun.update",uT)$_$_$_$_$_)) =>
        (let
          (* state type: fifth argument type of update *)
          val (_::_::_::_::sT::_) = binder_types uT;
          (* fresh schematic indices are allocated above this bound *)
          val mi = maxidx_of_term t;
          (* Abstract an update chain; returns the new chain and the next
             unused schematic index. *)
          fun mk_upds (Const ("StateFun.update",uT)$d'$c$m$v$s) =
               let
                 val (_::_::_::fT::_::_) = binder_types uT;
                 val vT = domain_type fT;
                 val (s',cnt) = mk_upds s;
                 val (v',cnt') =
                      (case v of
                        Const ("StateFun.K_statefun",KT)$v''
                         => (case v'' of
                             (Const ("StateFun.lookup",_)$(d as (Const ("Fun.id",_)))$n'$_)
                              => if d aconv c andalso n aconv m andalso m aconv n'
                                 then (v,cnt) (* Keep value so that
                                                 lookup_update_id_same can fire *)
                                 else (Const ("StateFun.K_statefun",KT)$Var (("v",cnt),vT),
                                       cnt+1)
                              | _ => (Const ("StateFun.K_statefun",KT)$Var (("v",cnt),vT),
                                       cnt+1))
                       | _ => (v,cnt));
               in (Const ("StateFun.update",uT)$d'$c$m$v'$s',cnt')
               end
            (* innermost state: ?s gets index mi+1; value variables start at mi+2 *)
            | mk_upds s = (Var (("s",mi+1),sT),mi+2);

          val ct = cterm_of thy
                    (Const ("StateFun.lookup",lT)$destr$n$(fst (mk_upds s)));
          val ctxt = Simplifier.the_context ss;
          val basic_ss = #1 (StateFunData.get (Context.Proof ctxt));
          val ss' = Simplifier.context
                     (Config.map MetaSimplifier.simp_depth_limit (K 100) ctxt) basic_ss;
          val thm = Simplifier.rewrite ss' ct;
        in if (op aconv) (Logic.dest_equals (prop_of thm))
           then NONE (* both sides aconv: nothing was rewritten *)
           else SOME thm
        end
        handle Option => NONE)

      | _ => NONE ));
   161 
   162 
   163 fun foldl1 f (x::xs) = foldl f x xs;
   164 
(* update_simproc: simplification procedure for update chains
     update d c n v s
   in which some name n is updated more than once.  All updates of such a name
   are merged into its outermost update, whose value becomes the composition
   (built with Fun.comp by mk_comps) of the recorded pieces.  The equation between
   the original and the merged skeleton is proved by rtac meta_ext followed by
   simp_tac with ss1.  The right-hand side of that equation is then rewritten
   with the lookup simpset from StateFunData (ss2), and the two are joined by
   transitivity.  Returns NONE when no name is updated twice. *)
local
(* Rules used in the proof of the merged-update equation. *)
val update_apply = thm "StateFun.update_apply";
val meta_ext = thm "StateFun.meta_ext";
val o_apply = thm "Fun.o_apply";
(* HOL_ss plus update_apply, o_apply, list/char injectivity, the simprocs
   DatatypePackage.distinct_simproc, lazy_conj_simproc and
   StateSpace.distinct_simproc, and the block_conj_cong congruence. *)
val ss' = (HOL_ss addsimps (update_apply::o_apply::thms "list.inject"@thms "char.inject")
                 addsimprocs [DatatypePackage.distinct_simproc,lazy_conj_simproc,StateSpace.distinct_simproc]
                 addcongs [thm "block_conj_cong"]);
in
val update_simproc =
  Simplifier.simproc (the_context ()) "update_simp" ["update d c n v s"]
    (fn thy => fn ss => fn t =>
      (case t of ((upd as Const ("StateFun.update", uT)) $ d $ c $ n $ v $ s) =>
         let

             val (_::_::_::_::sT::_) = binder_types uT;
                (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) => ('n => 'v) => ('n => 'v)"*)
             (* Result at the innermost state: both skeletons are Bound 0, the
                only bound variable is the state "s", no pieces recorded yet and
                no simplification has happened. *)
             fun init_seed s = (Bound 0,Bound 0, [("s",sT)],[], false);

             (* g o f, i.e. Fun.comp $ g $ f, together with its type. *)
             fun mk_comp f fT g gT =
               let val T = (domain_type fT --> range_type gT)
               in (Const ("Fun.comp",gT --> fT --> T)$g$f,T) end

             (* Fold a non-empty list of (term, type) pieces with mk_comp. *)
             fun mk_comps fs =
                   foldl1 (fn ((f,fT),(g,gT)) => mk_comp f fT g gT) fs;

             (* Record the pieces (c, f, d) for name n in the association list
                comps, in front of any pieces already recorded for n. *)
             fun append n c cT f fT d dT comps =
               (case AList.lookup (op aconv) comps n of
                  SOME gTs => AList.update (op aconv)
                                    (n,[(c,cT),(f,fT),(d,dT)]@gTs) comps
                | NONE => AList.update (op aconv) (n,[(c,cT),(f,fT),(d,dT)]) comps)

             (* Split a non-empty list into (first, middle, last). *)
             fun split_list (x::xs) = let val (xs',y) = split_last xs in (x,xs',y) end
               | split_list _ = error "StateFun.split_list";

             (* Merged update for n: first recorded c, last recorded d, and the
                composition of everything in between. *)
             fun merge_upds n comps =
               let val ((c,cT),fs,(d,dT)) = split_list (the (AList.lookup (op aconv) comps n))
               in ((c,cT),fst (mk_comps fs),(d,dT)) end;

             (* mk_updterm returns
              *  - (orig-term-skeleton,simplified-term-skeleton, vars, comps, b)
              *     where vars are the names and types of the bound variables used
              *     in the skeletons (one per update value, plus the seed state),
              *     comps is the association list of recorded (c, value, d)
              *     pieces per name, and the boolean b tells if a simplification
              *     has occurred.
              *     "orig-term-skeleton = simplified-term-skeleton" is
              *     the desired simplification rule.
              * The algorithm first walks down the updates to the seed-state while
              * memorising the updates in the already-table. While walking up the
              * updates again, the optimised term is constructed.
              *)
             fun mk_updterm already
                 (t as ((upd as Const ("StateFun.update", uT)) $ d $ c $ n $ v $ s)) =
                      let
                         fun rest already = mk_updterm already;
                         val (dT::cT::nT::vT::sT::_) = binder_types uT;
                          (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) =>
                                ('n => 'v) => ('n => 'v)"*)
                      in if member (op aconv) already n
                         (* n is updated again further out: drop this update from
                            the simplified skeleton and only record its pieces *)
                         then (case rest already s of
                                 (trm,trm',vars,comps,_) =>
                                   let
                                     val i = length vars;
                                     val kv = (mk_name i n,vT);
                                     val kb = Bound i;
                                     val comps' = append n c cT kb vT d dT comps;
                                   in (upd$d$c$n$kb$trm, trm', kv::vars,comps',true) end)
                         (* outermost update of n: emit the merged update here *)
                         else
                          (case rest (n::already) s of
                             (trm,trm',vars,comps,b) =>
                                let
                                   val i = length vars;
                                   val kv = (mk_name i n,vT);
                                   val kb = Bound i;
                                   val comps' = append n c cT kb vT d dT comps;
                                   val ((c',c'T),f',(d',d'T)) = merge_upds n comps';
                                   val vT' = range_type d'T --> domain_type c'T;
                                   val upd' = Const ("StateFun.update",d'T --> c'T --> nT --> vT' --> sT --> sT);
                                in (upd$d$c$n$kb$trm, upd'$d'$c'$n$f'$trm', kv::vars,comps',b)
                                end)
                      end
               | mk_updterm _ t = init_seed t;

             val ctxt = Simplifier.the_context ss |>
                        Config.map MetaSimplifier.simp_depth_limit (K 100);
             val ss1 = Simplifier.context ctxt ss';
             val ss2 = Simplifier.context ctxt
                         (#1 (StateFunData.get (Context.Proof ctxt)));
         in (case mk_updterm [] t of
               (trm,trm',vars,_,true)
                => let
                     val eq1 = Goal.prove ctxt [] []
                                      (list_all (vars, Logic.mk_equals (trm, trm')))
                                      (fn _ => rtac meta_ext 1 THEN
                                               simp_tac ss1 1);
                     (* further rewrite the right-hand side with the lookup simpset *)
                     val eq2 = Simplifier.asm_full_rewrite ss2 (Thm.dest_equals_rhs (cprop_of eq1));
                   in SOME (transitive eq1 eq2) end
             | _ => NONE)
         end
       | _ => NONE));
end
   262 
   263 
   264 
   265 
   266 local
   267 val swap_ex_eq = thm "StateFun.swap_ex_eq";
   268 fun is_selector thy T sel =
   269      let 
   270        val (flds,more) = RecordPackage.get_recT_fields thy T 
   271      in member (fn (s,(n,_)) => n=s) (more::flds) sel
   272      end;
   273 in
   274 val ex_lookup_eq_simproc =
   275   Simplifier.simproc @{theory HOL} "ex_lookup_eq_simproc" ["Ex t"]
   276     (fn thy => fn ss => fn t =>
   277        let
   278          val ctxt = Simplifier.the_context ss |>
   279                     Config.map MetaSimplifier.simp_depth_limit (K 100)
   280          val ex_lookup_ss = #2 (StateFunData.get (Context.Proof ctxt));
   281          val ss' = (Simplifier.context ctxt ex_lookup_ss);
   282          fun prove prop =
   283            Goal.prove_global thy [] [] prop 
   284              (fn _ => record_split_simp_tac [] (K ~1) 1 THEN
   285                       simp_tac ss' 1);
   286 
   287          fun mkeq (swap,Teq,lT,lo,d,n,x,s) i =
   288                let val (_::nT::_) = binder_types lT;
   289                          (*  ('v => 'a) => 'n => ('n => 'v) => 'a *)
   290                    val x' = if not (loose_bvar1 (x,0))
   291 			    then Bound 1
   292                             else raise TERM ("",[x]);
   293                    val n' = if not (loose_bvar1 (n,0))
   294 			    then Bound 2
   295                             else raise TERM ("",[n]);
   296                    val sel' = lo $ d $ n' $ s;
   297                   in (Const ("op =",Teq)$sel'$x',hd (binder_types Teq),nT,swap) end;
   298 
   299          fun dest_state (s as Bound 0) = s
   300            | dest_state (s as (Const (sel,sT)$Bound 0)) =
   301                if is_selector thy (domain_type sT) sel
   302                then s
   303                else raise TERM ("StateFun.ex_lookup_eq_simproc: not a record slector",[s])
   304            | dest_state s = 
   305                     raise TERM ("StateFun.ex_lookup_eq_simproc: not a record slector",[s]);
   306   
   307          fun dest_sel_eq (Const ("op =",Teq)$
   308                            ((lo as (Const ("StateFun.lookup",lT)))$d$n$s)$X) =
   309                            (false,Teq,lT,lo,d,n,X,dest_state s)
   310            | dest_sel_eq (Const ("op =",Teq)$X$
   311                             ((lo as (Const ("StateFun.lookup",lT)))$d$n$s)) =
   312                            (true,Teq,lT,lo,d,n,X,dest_state s)
   313            | dest_sel_eq _ = raise TERM ("",[]);
   314 
   315        in
   316          (case t of
   317            (Const ("Ex",Tex)$Abs(s,T,t)) =>
   318              (let val (eq,eT,nT,swap) = mkeq (dest_sel_eq t) 0;
   319                   val prop = list_all ([("n",nT),("x",eT)],
   320                               Logic.mk_equals (Const ("Ex",Tex)$Abs(s,T,eq),
   321                                                HOLogic.true_const));
   322                   val thm = standard (prove prop);
   323                   val thm' = if swap then swap_ex_eq OF [thm] else thm
   324              in SOME thm' end
   325              handle TERM _ => NONE)
   326           | _ => NONE)
   327         end handle Option => NONE) 
   328 end;
   329 
   330 val val_sfx = "V";
   331 val val_prfx = "StateFun."
   332 fun deco base_prfx s = val_prfx ^ (base_prfx ^ suffix val_sfx s);
   333 
   334 fun mkUpper str = 
   335   (case String.explode str of
   336     [] => ""
   337    | c::cs => String.implode (Char.toUpper c::cs ))
   338 
   339 fun mkName (Type (T,args)) = concat (map mkName args) ^ mkUpper (Long_Name.base_name T)
   340   | mkName (TFree (x,_)) = mkUpper (Long_Name.base_name x)
   341   | mkName (TVar ((x,_),_)) = mkUpper (Long_Name.base_name x);
   342 
   343 fun is_datatype thy n = is_some (Symtab.lookup (DatatypePackage.get_datatypes thy) n);
   344 
   345 fun mk_map "List.list" = Syntax.const "List.map"
   346   | mk_map n = Syntax.const ("StateFun.map_" ^ Long_Name.base_name n);
   347 
   348 fun gen_constr_destr comp prfx thy (Type (T,[])) = 
   349       Syntax.const (deco prfx (mkUpper (Long_Name.base_name T)))
   350   | gen_constr_destr comp prfx thy (T as Type ("fun",_)) =
   351      let val (argTs,rangeT) = strip_type T;
   352      in comp 
   353           (Syntax.const (deco prfx (concat (map mkName argTs) ^ "Fun")))
   354           (fold (fn x => fn y => x$y)
   355                (replicate (length argTs) (Syntax.const "StateFun.map_fun"))
   356                (gen_constr_destr comp prfx thy rangeT))
   357      end
   358   | gen_constr_destr comp prfx thy (T' as Type (T,argTs)) = 
   359      if is_datatype thy T
   360      then (* datatype args are recursively embedded into val *)
   361          (case argTs of
   362            [argT] => comp 
   363                      ((Syntax.const (deco prfx (mkUpper (Long_Name.base_name T)))))
   364                      ((mk_map T $ gen_constr_destr comp prfx thy argT))
   365           | _ => raise (TYPE ("StateFun.gen_constr_destr",[T'],[])))
   366      else (* type args are not recursively embedded into val *)
   367            Syntax.const (deco prfx (concat (map mkName argTs) ^ mkUpper (Long_Name.base_name T)))
   368   | gen_constr_destr thy _ _ T = raise (TYPE ("StateFun.gen_constr_destr",[T],[]));
   369                    
   370 val mk_constr = gen_constr_destr (fn a => fn b => Syntax.const "Fun.comp" $ a $ b) ""
   371 val mk_destr =  gen_constr_destr (fn a => fn b => Syntax.const "Fun.comp" $ b $ a) "the_"
   372 
   373   
   374 fun statefun_simp_attr src (ctxt,thm) = 
   375   let
   376      val (lookup_ss,ex_lookup_ss,simprocs_active) = StateFunData.get ctxt;
   377      val (lookup_ss', ex_lookup_ss') = 
   378 	   (case (concl_of thm) of
   379             (_$((Const ("Ex",_)$_))) => (lookup_ss, ex_lookup_ss addsimps [thm])
   380             | _ => (lookup_ss addsimps [thm], ex_lookup_ss))
   381      fun activate_simprocs ctxt =
   382           if simprocs_active then ctxt
   383           else Simplifier.map_ss (fn ss => ss addsimprocs [lookup_simproc,update_simproc]) ctxt
   384                
   385 
   386      val ctxt' = ctxt 
   387          |> activate_simprocs
   388          |> (StateFunData.put (lookup_ss',ex_lookup_ss',true))
   389   in (ctxt', thm) end;
   390 
   391 val setup = 
   392     init_state_fun_data 
   393     #> Attrib.add_attributes 
   394 	  [("statefun_simp",statefun_simp_attr,"simplification in statespaces")]     
   395 end