src/HOL/Statespace/state_fun.ML
changeset 25171 4a9c25bffc9b
child 25408 156f6f7082b8
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Statespace/state_fun.ML	Wed Oct 24 18:36:09 2007 +0200
@@ -0,0 +1,397 @@
+(*  Title:      state_fun.ML
+    ID:         $Id$
+    Author:     Norbert Schirmer, TU Muenchen
+*)
+
+
+structure StateFun =
+struct
+
+val lookupN = "StateFun.lookup";
+val updateN = "StateFun.update";
+
+
+fun dest_nib c =
+     let val h = List.last (String.explode c) 
+     in if #"0" <= h andalso h <= #"9" then Char.ord h - Char.ord #"0"
+        else if #"A" <= h andalso h <= #"F" then Char.ord h - Char.ord #"A" + 10
+        else raise Match
+     end;
+
+fun dest_chr (Const ("List.char.Char",_)$Const (c1,_)$(Const (c2,_))) = 
+    let val c = Char.chr (dest_nib c1 * 16 + dest_nib c2)
+    in if Char.isPrint c then c else raise Match end
+  | dest_chr _ = raise Match;
+
+fun dest_string (Const ("List.list.Nil",_)) = []
+  | dest_string (Const ("List.list.Cons",_)$c$cs) = dest_chr c::dest_string cs
+  | dest_string _ = raise TERM ("dest_string",[]);
+
+fun sel_name n = String.implode (dest_string n);
+
+fun mk_name i t =
+  (case try sel_name t of
+     SOME name => name
+   | NONE => (case t of 
+               Free (x,_) => x
+              |Const (x,_) => x
+              |_ => "x"^string_of_int i))
+               
+local
+val conj1_False = thm "conj1_False";
+val conj2_False = thm "conj2_False";
+val conj_True = thm "conj_True";
+val conj_cong = thm "conj_cong";
+fun isFalse (Const ("False",_)) = true
+  | isFalse _ = false;
+fun isTrue (Const ("True",_)) = true
+  | isTrue _ = false;
+
+in
+val lazy_conj_simproc =
+  Simplifier.simproc HOL.thy "lazy_conj_simp" ["P & Q"]
+    (fn thy => fn ss => fn t =>
+      (case t of (Const ("op &",_)$P$Q) => 
+         let
+            val P_P' = Simplifier.rewrite ss (cterm_of thy P);
+            val P' = P_P' |> prop_of |> Logic.dest_equals |> #2 
+         in if isFalse P'
+            then SOME (conj1_False OF [P_P'])
+            else 
+              let
+                val Q_Q' = Simplifier.rewrite ss (cterm_of thy Q);
+                val Q' = Q_Q' |> prop_of |> Logic.dest_equals |> #2 
+              in if isFalse Q'
+                 then SOME (conj2_False OF [Q_Q'])
+                 else if isTrue P' andalso isTrue Q'
+                      then SOME (conj_True OF [P_P', Q_Q'])
+                      else if P aconv P' andalso Q aconv Q' then NONE
+                           else SOME (conj_cong OF [P_P', Q_Q'])
+              end 
+         end
+        
+      | _ => NONE));
+
+val string_eq_simp_tac =
+     simp_tac (HOL_basic_ss 
+                 addsimps (thms "list.inject"@thms "char.inject"@simp_thms)
+                 addsimprocs [distinct_simproc,lazy_conj_simproc]
+                 addcongs [thm "block_conj_cong"])
+end;
+
+
+
+local
+val rules = 
+ [thm "StateFun.lookup_update_id_same",
+  thm "StateFun.id_id_cancel",
+  thm "StateFun.lookup_update_same",thm "StateFun.lookup_update_other"
+  ]
+in
+val lookup_ss = (HOL_basic_ss 
+                 addsimps (thms "list.inject"@thms "char.inject"@simp_thms@rules)
+                 addsimprocs [distinct_simproc,lazy_conj_simproc]
+                 addcongs [thm "block_conj_cong"]
+                 addSolver StateSpace.distinctNameSolver) 
+end;
+
+val ex_lookup_ss = HOL_ss addsimps [thm "StateFun.ex_id"];
+
+structure StateFunArgs =
+struct
+  type T = (simpset * simpset * bool); 
+           (* lookup simpset, ex_lookup simpset, are simprocs installed *)
+  val empty = (empty_ss, empty_ss, false);
+  val extend = I;
+  fun merge pp ((ss1,ex_ss1,b1),(ss2,ex_ss2,b2)) = 
+               (merge_ss (ss1,ss2)
+               ,merge_ss (ex_ss1,ex_ss2)
+               ,b1 orelse b2);
+end;
+
+
+structure StateFunData = GenericDataFun(StateFunArgs);
+
+val init_state_fun_data =
+  Context.theory_map (StateFunData.put (lookup_ss,ex_lookup_ss,false));
+
+val lookup_simproc =
+  Simplifier.simproc (the_context ()) "lookup_simp" ["lookup d n (update d' c m v s)"]
+    (fn thy => fn ss => fn t =>
+      (case t of (Const ("StateFun.lookup",lT)$destr$n$
+                   (s as Const ("StateFun.update",uT)$_$_$_$_$_)) =>
+        (let
+          val (_::_::_::_::sT::_) = binder_types uT;
+          val mi = maxidx_of_term t;
+          fun mk_upds (Const ("StateFun.update",uT)$d'$c$m$v$s) =
+               let
+                 val (_::_::_::fT::_::_) = binder_types uT;
+                 val vT = domain_type fT;
+                 val (s',cnt) = mk_upds s;
+                 val (v',cnt') = 
+                      (case v of
+                        Const ("StateFun.K_statefun",KT)$v''
+                         => (case v'' of 
+                             (Const ("StateFun.lookup",_)$(d as (Const ("Fun.id",_)))$n'$_)
+                              => if d aconv c andalso n aconv m andalso m aconv n' 
+                                 then (v,cnt) (* Keep value so that 
+                                                 lookup_update_id_same can fire *)
+                                 else (Const ("StateFun.K_statefun",KT)$Var (("v",cnt),vT),
+                                       cnt+1)
+                              | _ => (Const ("StateFun.K_statefun",KT)$Var (("v",cnt),vT),
+                                       cnt+1))
+                       | _ => (v,cnt));
+               in (Const ("StateFun.update",uT)$d'$c$m$v'$s',cnt')
+               end
+            | mk_upds s = (Var (("s",mi+1),sT),mi+2);
+          
+          val ct = cterm_of thy 
+                    (Const ("StateFun.lookup",lT)$destr$n$(fst (mk_upds s)));
+          val ctxt = the (#context (#1 (rep_ss ss)));
+          val basic_ss = #1 (StateFunData.get (Context.Proof ctxt));
+          val ss' = Simplifier.context 
+                     (Config.map MetaSimplifier.simp_depth_limit (K 100) ctxt) basic_ss;
+          val thm = Simplifier.rewrite ss' ct;
+        in if (op aconv) (Logic.dest_equals (prop_of thm))
+           then NONE
+           else SOME thm
+        end
+        handle Option => NONE)
+         
+      | _ => NONE ));
+
+
+fun foldl1 f (x::xs) = foldl f x xs;
+
+local
+val update_apply = thm "StateFun.update_apply";
+val meta_ext = thm "StateFun.meta_ext";
+val o_apply = thm "Fun.o_apply";
+val ss' = (HOL_ss addsimps (update_apply::o_apply::thms "list.inject"@thms "char.inject")
+                 addsimprocs [distinct_simproc,lazy_conj_simproc,StateSpace.distinct_simproc]
+                 addcongs [thm "block_conj_cong"]);
+in
+val update_simproc =
+  Simplifier.simproc (the_context ()) "update_simp" ["update d c n v s"]
+    (fn thy => fn ss => fn t =>
+      (case t of ((upd as Const ("StateFun.update", uT)) $ d $ c $ n $ v $ s) =>
+         let 
+            
+             val (_::_::_::_::sT::_) = binder_types uT;
+                (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) => ('n => 'v) => ('n => 'v)"*)
+             fun init_seed s = (Bound 0,Bound 0, [("s",sT)],[], false);
+
+             fun mk_comp f fT g gT =
+               let val T = (domain_type fT --> range_type gT) 
+               in (Const ("Fun.comp",gT --> fT --> T)$g$f,T) end
+
+             fun mk_comps fs = 
+                   foldl1 (fn ((f,fT),(g,gT)) => mk_comp f fT g gT) fs;
+
+             fun append n c cT f fT d dT comps =
+               (case AList.lookup (op aconv) comps n of
+                  SOME gTs => AList.update (op aconv) 
+                                    (n,[(c,cT),(f,fT),(d,dT)]@gTs) comps
+                | NONE => AList.update (op aconv) (n,[(c,cT),(f,fT),(d,dT)]) comps)
+
+             fun split_list (x::xs) = let val (xs',y) = split_last xs in (x,xs',y) end
+               | split_list _ = error "StateFun.split_list";
+
+             fun merge_upds n comps =
+               let val ((c,cT),fs,(d,dT)) = split_list (the (AList.lookup (op aconv) comps n))
+               in ((c,cT),fst (mk_comps fs),(d,dT)) end;
+
+             (* mk_updterm returns 
+              *  - (orig-term-skeleton,simplified-term-skeleton, vars, b)
+              *     where boolean b tells if a simplification has occured.
+                    "orig-term-skeleton = simplified-term-skeleton" is
+              *     the desired simplification rule.
+              * The algorithm first walks down the updates to the seed-state while
+              * memorising the updates in the already-table. While walking up the
+              * updates again, the optimised term is constructed.
+              *)
+             fun mk_updterm already
+                 (t as ((upd as Const ("StateFun.update", uT)) $ d $ c $ n $ v $ s)) =
+                      let
+                         fun rest already = mk_updterm already;
+                         val (dT::cT::nT::vT::sT::_) = binder_types uT;
+                          (*"('v => 'a1) => ('a2 => 'v) => 'n => ('a1 => 'a2) => 
+                                ('n => 'v) => ('n => 'v)"*)
+                      in if member (op aconv) already n
+                         then (case rest already s of
+                                 (trm,trm',vars,comps,_) =>
+                                   let
+                                     val i = length vars;
+                                     val kv = (mk_name i n,vT);
+                                     val kb = Bound i;
+                                     val comps' = append n c cT kb vT d dT comps;
+                                   in (upd$d$c$n$kb$trm, trm', kv::vars,comps',true) end)
+                         else
+                          (case rest (n::already) s of
+                             (trm,trm',vars,comps,b) =>
+                                let
+                                   val i = length vars;
+                                   val kv = (mk_name i n,vT);
+                                   val kb = Bound i;
+                                   val comps' = append n c cT kb vT d dT comps;
+                                   val ((c',c'T),f',(d',d'T)) = merge_upds n comps';
+                                   val vT' = range_type d'T --> domain_type c'T;
+                                   val upd' = Const ("StateFun.update",d'T --> c'T --> nT --> vT' --> sT --> sT);
+                                in (upd$d$c$n$kb$trm, upd'$d'$c'$n$f'$trm', kv::vars,comps',b) 
+                                end)
+                      end
+               | mk_updterm _ t = init_seed t;
+
+	     val ctxt = the (#context (#1 (rep_ss ss))) |>
+                        Config.map MetaSimplifier.simp_depth_limit (K 100);
+             val ss1 = Simplifier.context ctxt ss';
+             val ss2 = Simplifier.context ctxt 
+                         (#1 (StateFunData.get (Context.Proof ctxt)));
+         in (case mk_updterm [] t of
+               (trm,trm',vars,_,true)
+                => let
+                     val eq1 = Goal.prove ctxt [] [] 
+                                      (list_all (vars,equals sT$trm$trm'))
+                                      (fn _ => rtac meta_ext 1 THEN 
+                                               simp_tac ss1 1);
+                     val eq2 = Simplifier.asm_full_rewrite ss2 (Thm.dest_equals_rhs (cprop_of eq1));
+                   in SOME (transitive eq1 eq2) end
+             | _ => NONE)
+         end
+       | _ => NONE));
+end
+
+
+
+
+local
+val swap_ex_eq = thm "StateFun.swap_ex_eq";
+fun is_selector thy T sel =
+     let 
+       val (flds,more) = RecordPackage.get_recT_fields thy T 
+     in member (fn (s,(n,_)) => n=s) (more::flds) sel
+     end;
+in
+val ex_lookup_eq_simproc =
+  Simplifier.simproc HOL.thy "ex_lookup_eq_simproc" ["Ex t"]
+    (fn thy => fn ss => fn t =>
+       let
+         val ctxt = Simplifier.the_context ss |>
+                    Config.map MetaSimplifier.simp_depth_limit (K 100)
+         val ex_lookup_ss = #2 (StateFunData.get (Context.Proof ctxt));
+         val ss' = (Simplifier.context ctxt ex_lookup_ss);
+         fun prove prop =
+           Goal.prove_global thy [] [] prop 
+             (fn _ => record_split_simp_tac [] (K ~1) 1 THEN
+                      simp_tac ss' 1);
+
+         fun mkeq (swap,Teq,lT,lo,d,n,x,s) i =
+               let val (_::nT::_) = binder_types lT;
+                         (*  ('v => 'a) => 'n => ('n => 'v) => 'a *)
+                   val x' = if not (loose_bvar1 (x,0))
+			    then Bound 1
+                            else raise TERM ("",[x]);
+                   val n' = if not (loose_bvar1 (n,0))
+			    then Bound 2
+                            else raise TERM ("",[n]);
+                   val sel' = lo $ d $ n' $ s;
+                  in (Const ("op =",Teq)$sel'$x',hd (binder_types Teq),nT,swap) end;
+
+         fun dest_state (s as Bound 0) = s
+           | dest_state (s as (Const (sel,sT)$Bound 0)) =
+               if is_selector thy (domain_type sT) sel
+               then s
+               else raise TERM ("StateFun.ex_lookup_eq_simproc: not a record slector",[s])
+           | dest_state s = 
+                    raise TERM ("StateFun.ex_lookup_eq_simproc: not a record slector",[s]);
+  
+         fun dest_sel_eq (Const ("op =",Teq)$
+                           ((lo as (Const ("StateFun.lookup",lT)))$d$n$s)$X) =
+                           (false,Teq,lT,lo,d,n,X,dest_state s)
+           | dest_sel_eq (Const ("op =",Teq)$X$
+                            ((lo as (Const ("StateFun.lookup",lT)))$d$n$s)) =
+                           (true,Teq,lT,lo,d,n,X,dest_state s)
+           | dest_sel_eq _ = raise TERM ("",[]);
+
+       in
+         (case t of
+           (Const ("Ex",Tex)$Abs(s,T,t)) =>
+             (let val (eq,eT,nT,swap) = mkeq (dest_sel_eq t) 0;
+                  val prop = list_all ([("n",nT),("x",eT)],
+                              Logic.mk_equals (Const ("Ex",Tex)$Abs(s,T,eq),
+                                               HOLogic.true_const));
+                  val thm = standard (prove prop);
+                  val thm' = if swap then swap_ex_eq OF [thm] else thm
+             in SOME thm' end
+             handle TERM _ => NONE)
+          | _ => NONE)
+        end handle Option => NONE) 
+end;
+
+val val_sfx = "V";
+val val_prfx = "StateFun."
+fun deco base_prfx s = val_prfx ^ (base_prfx ^ suffix val_sfx s);
+
+fun mkUpper str = 
+  (case String.explode str of
+    [] => ""
+   | c::cs => String.implode (Char.toUpper c::cs ))
+
+fun mkName (Type (T,args)) = concat (map mkName args) ^ mkUpper (NameSpace.base T)
+  | mkName (TFree (x,_)) = mkUpper (NameSpace.base x)
+  | mkName (TVar ((x,_),_)) = mkUpper (NameSpace.base x);
+
+fun is_datatype thy n = is_some (Symtab.lookup (DatatypePackage.get_datatypes thy) n);
+
+fun mk_map ("List.list") = Syntax.const "List.map"
+  | mk_map n = Syntax.const ("StateFun." ^  "map_" ^ NameSpace.base n);
+
+fun gen_constr_destr comp prfx thy (Type (T,[])) = 
+      Syntax.const (deco prfx (mkUpper (NameSpace.base T)))
+  | gen_constr_destr comp prfx thy (T as Type ("fun",_)) =
+     let val (argTs,rangeT) = strip_type T;
+     in comp 
+          (Syntax.const (deco prfx (concat (map mkName argTs) ^ "Fun")))
+          (fold (fn x => fn y => x$y)
+               (replicate (length argTs) (Syntax.const "StateFun.map_fun"))
+               (gen_constr_destr comp prfx thy rangeT))
+     end
+  | gen_constr_destr comp prfx thy (T' as Type (T,argTs)) = 
+     if is_datatype thy T
+     then (* datatype args are recursively embedded into val *)
+         (case argTs of
+           [argT] => comp 
+                     ((Syntax.const (deco prfx (mkUpper (NameSpace.base T)))))
+                     ((mk_map T $ gen_constr_destr comp prfx thy argT))
+          | _ => raise (TYPE ("StateFun.gen_constr_destr",[T'],[])))
+     else (* type args are not recursively embedded into val *)
+           Syntax.const (deco prfx (concat (map mkName argTs) ^ mkUpper (NameSpace.base T)))
+  | gen_constr_destr thy _ _ T = raise (TYPE ("StateFun.gen_constr_destr",[T],[]));
+                   
+val mk_constr = gen_constr_destr (fn a => fn b => Syntax.const "Fun.comp" $ a $ b) ""
+val mk_destr =  gen_constr_destr (fn a => fn b => Syntax.const "Fun.comp" $ b $ a) "the_"
+
+  
+fun statefun_simp_attr src (ctxt,thm) = 
+  let
+     val (lookup_ss,ex_lookup_ss,simprocs_active) = StateFunData.get ctxt;
+     val (lookup_ss', ex_lookup_ss') = 
+	   (case (concl_of thm) of
+            (_$((Const ("Ex",_)$_))) => (lookup_ss, ex_lookup_ss addsimps [thm])
+            | _ => (lookup_ss addsimps [thm], ex_lookup_ss))
+     fun activate_simprocs ctxt =
+          if simprocs_active then ctxt
+          else StateSpace.change_simpset 
+                (fn ss => ss addsimprocs [lookup_simproc,update_simproc]) ctxt
+               
+
+     val ctxt' = ctxt 
+         |> activate_simprocs
+         |> (StateFunData.put (lookup_ss',ex_lookup_ss',true))
+  in (ctxt', thm) end;
+
+val setup = 
+    init_state_fun_data 
+    #> Attrib.add_attributes 
+	  [("statefun_simp",statefun_simp_attr,"simplification in statespaces")]     
+end