src/HOL/Tools/record_package.ML
changeset 25705 45a2ffc5911e
parent 25179 b84f3c3c27f2
child 26065 d80a49f51b94
--- a/src/HOL/Tools/record_package.ML	Tue Dec 18 22:21:42 2007 +0100
+++ b/src/HOL/Tools/record_package.ML	Wed Dec 19 16:32:12 2007 +0100
@@ -34,7 +34,6 @@
   val makeN: string
   val moreN: string
   val ext_dest: string
-  val KN:string
 
   val last_extT: typ -> (string * typ list) option
   val dest_recTs : typ -> (string * typ list) list
@@ -64,8 +63,8 @@
 val meta_allE = thm "Pure.meta_allE";
 val prop_subst = thm "prop_subst";
 val Pair_sel_convs = [fst_conv,snd_conv];
-val K_record_apply = thm "Record.K_record_apply";
-val K_comp_convs = [o_apply,K_record_apply]
+val K_record_comp = thm "K_record_comp";
+val K_comp_convs = [o_apply, K_record_comp]
 
 (** name components **)
 
@@ -83,7 +82,6 @@
 val fields_selN = "fields";
 val extendN = "extend";
 val truncateN = "truncate";
-val KN = "Record.K_record";
 
 (*see typedef_package.ML*)
 val RepN = "Rep_";
@@ -526,7 +524,7 @@
 (* parse translations *)
 
 fun gen_field_tr mark sfx (t as Const (c, _) $ Const (name, _) $ arg) =
-      if c = mark then Syntax.const (suffix sfx name) $ (Syntax.const KN $ arg)
+      if c = mark then Syntax.const (suffix sfx name) $ (Abs ("_",dummyT, arg))
       else raise TERM ("gen_field_tr: " ^ mark, [t])
   | gen_field_tr mark _ t = raise TERM ("gen_field_tr: " ^ mark, [t]);
 
@@ -673,11 +671,21 @@
 val print_record_type_abbr = ref true;
 val print_record_type_as_fields = ref true;
 
-fun gen_field_upds_tr' mark sfx (tm as Const (name_field, _) $ (Const ("K_record",_)$t) $ u) =
+fun gen_field_upds_tr' mark sfx (tm as Const (name_field, _) $ k $ u) =
+  let val t = (case k of (Abs (_,_,(Abs (_,_,t)$Bound 0))) 
+                  => if null (loose_bnos t) then t else raise Match
+               | Abs (x,_,t) => if null (loose_bnos t) then t else raise Match
+               | _ => raise Match)
+
+      (* (case k of (Const ("K_record",_)$t) => t
+               | Abs (x,_,Const ("K_record",_)$t$Bound 0) => t
+               | _ => raise Match)*)
+  in
     (case try (unsuffix sfx) name_field of
       SOME name =>
         apfst (cons (Syntax.const mark $ Syntax.free name $ t)) (gen_field_upds_tr' mark sfx u)
      | NONE => ([], tm))
+  end
   | gen_field_upds_tr' _ _ tm = ([], tm);
 
 fun record_update_tr' tm =
@@ -966,6 +974,28 @@
 fun has_field extfields f T =
      exists (fn (eN,_) => exists (eq f o fst) (Symtab.lookup_list extfields eN))
        (dest_recTs T);
+
+fun K_skeleton n (T as Type (_,[_,kT])) (b as Bound i) (Abs (x,xT,t)) =
+     if null (loose_bnos t) then ((n,kT),(Abs (x,xT,Bound (i+1)))) else ((n,T),b)
+  | K_skeleton n T b _ = ((n,T),b);
+
+(*
+fun K_skeleton n _ b ((K_rec as Const ("Record.K_record",Type (_,[kT,_])))$_) = 
+      ((n,kT),K_rec$b)
+  | K_skeleton n _ (Bound i) 
+      (Abs (x,T,(K_rec as Const ("Record.K_record",Type (_,[kT,_])))$_$Bound 0)) =
+        ((n,kT),Abs (x,T,(K_rec$Bound (i+1)$Bound 0)))
+  | K_skeleton n T b  _ = ((n,T),b);
+ *)
+
+fun normalize_rhs thm =
+  let
+     val ss = HOL_basic_ss addsimps K_comp_convs; 
+     val rhs = thm |> Thm.cprop_of |> Thm.dest_comb |> snd;
+     val _ = tracing ("rhs:"^(Pretty.string_of (Display.pretty_cterm rhs)));
+     val rhs' = (Simplifier.rewrite ss rhs);
+     val _ = tracing ("rhs':"^(Pretty.string_of (Display.pretty_thm rhs')));
+  in Thm.transitive thm rhs' end;
 in
 (* record_simproc *)
 (* Simplifies selections of an record update:
@@ -999,13 +1029,11 @@
                                  let
                                    val rv = ("r",rT)
                                    val rb = Bound 0
-                                   val kv = ("k",kT)
-                                   val kb = Bound 1
+                                   val (kv,kb) = K_skeleton "k" kT (Bound 1) k;
                                   in SOME (upd$kb$rb,kb$(sel$rb),[kv,rv]) end
                               | SOME (trm,trm',vars) =>
                                  let
-                                   val kv = ("k",kT)
-                                   val kb = Bound (length vars)
+                                   val (kv,kb) = K_skeleton "k" kT (Bound (length vars)) k;
                                  in SOME (upd$kb$trm,kb$trm',kv::vars) end)
                         else if has_field extfields u_name rangeS
                              orelse has_field extfields s (domain_type kT)
@@ -1013,15 +1041,14 @@
                              else (case mk_eq_terms r of
                                      SOME (trm,trm',vars)
                                      => let
-                                          val kv = ("k",kT)
-                                          val kb = Bound (length vars)
+                                          val (kv,kb) = 
+                                                 K_skeleton "k" kT (Bound (length vars)) k;
                                         in SOME (upd$kb$trm,trm',kv::vars) end
                                    | NONE
                                      => let
                                           val rv = ("r",rT)
                                           val rb = Bound 0
-                                          val kv = ("k",kT)
-                                          val kb = Bound 1
+                                          val (kv,kb) = K_skeleton "k" kT (Bound 1) k;
                                         in SOME (upd$kb$rb,sel$rb,[kv,rv]) end))
                 | mk_eq_terms r = NONE
             in
@@ -1061,18 +1088,24 @@
 
              fun grow u uT k kT vars (sprout,skeleton) =
                    if sel_name u = moreN
-                   then let val kv = ("k", kT);
-                            val kb = Bound (length vars);
+                   then let val (kv,kb) = K_skeleton "k" kT (Bound (length vars)) k;
                         in ((Const (u,uT)$k$sprout,Const (u,uT)$kb$skeleton),kv::vars) end
                    else ((sprout,skeleton),vars);
 
-             fun is_upd_same (sprout,skeleton) u
-                                ((K_rec as Const ("Record.K_record",_))$
-                                  ((sel as Const (s,_))$r)) =
+
+             fun dest_k (Abs (x,T,((sel as Const (s,_))$r))) =
+                  if null (loose_bnos r) then SOME (x,T,sel,s,r) else NONE
+               | dest_k (Abs (_,_,(Abs (x,T,((sel as Const (s,_))$r)))$Bound 0)) =
+                  (* eta expanded variant *)
+                  if null (loose_bnos r) then SOME (x,T,sel,s,r) else NONE
+               | dest_k _ = NONE;
+
+             fun is_upd_same (sprout,skeleton) u k =
+               (case dest_k k of SOME (x,T,sel,s,r) =>
                    if (unsuffix updateN u) = s andalso (seed s sprout) = r
-                   then SOME (K_rec,sel,seed s skeleton)
+                   then SOME (fn t => Abs (x,T,incr_boundvars 1 t),sel,seed s skeleton)
                    else NONE
-               | is_upd_same _ _ _ = NONE
+                | NONE => NONE);
 
              fun init_seed r = ((r,Bound 0), [("r", rT)]);
 
@@ -1111,15 +1144,13 @@
                                  Init ((sprout,skel),vars) =>
                                  let
                                    val n = sel_name u;
-                                   val kv = (n, kT);
-                                   val kb = Bound (length vars);
+                                   val (kv,kb) = K_skeleton n kT (Bound (length vars)) k;
                                    val (sprout',vars')= grow u uT k kT (kv::vars) (sprout,skel);
                                  in Inter (upd$kb$skel,skel,vars',add n kb [],sprout') end
                                | Inter (trm,trm',vars,fmaps,sprout) =>
                                  let
                                    val n = sel_name u;
-                                   val kv = (n, kT);
-                                   val kb = Bound (length vars);
+                                   val (kv,kb) = K_skeleton n kT (Bound (length vars)) k;
                                    val (sprout',vars') = grow u uT k kT (kv::vars) sprout;
                                  in Inter(upd$kb$trm,trm',kv::vars',add n kb fmaps,sprout')
                                  end)
@@ -1130,26 +1161,25 @@
                                  SOME (K_rec,sel,skel') =>
                                  let
                                    val (sprout',vars') = grow u uT k kT vars (sprout,skel);
-                                  in Inter(upd$(K_rec$(sel$skel'))$skel,skel,vars',[],sprout')
+                                  in Inter(upd$(K_rec (sel$skel'))$skel,skel,vars',[],sprout')
                                   end
                                | NONE =>
                                  let
-                                   val kv = (sel_name u, kT);
-                                   val kb = Bound (length vars);
+                                   val n = sel_name u;
+                                   val (kv,kb) = K_skeleton n kT (Bound (length vars)) k;
                                  in Init ((upd$k$sprout,upd$kb$skel),kv::vars) end)
                            | Inter (trm,trm',vars,fmaps,sprout) =>
                                (case is_upd_same sprout u k of
                                   SOME (K_rec,sel,skel) =>
                                   let
                                     val (sprout',vars') = grow u uT k kT vars sprout
-                                  in Inter(upd$(K_rec$(sel$skel))$trm,trm',vars',fmaps,sprout')
+                                  in Inter(upd$(K_rec (sel$skel))$trm,trm',vars',fmaps,sprout')
                                   end
                                 | NONE =>
                                   let
                                     val n = sel_name u
                                     val T = domain_type kT
-                                    val kv = (n, kT)
-                                    val kb = Bound (length vars)
+                                    val (kv,kb) = K_skeleton n kT (Bound (length vars)) k;
                                     val (sprout',vars') = grow u uT k kT (kv::vars) sprout
                                     val fmaps' = add n kb fmaps
                                   in Inter (upd$kb$trm,upd$comps n T fmaps'$trm'
@@ -1160,8 +1190,9 @@
 
          in (case mk_updterm updates [] t of
                Inter (trm,trm',vars,_,_)
-                => SOME (prove_split_simp thy ss rT
-                          (list_all(vars,(equals rT$trm$trm'))))
+                => SOME (normalize_rhs 
+                          (prove_split_simp thy ss rT
+                            (list_all(vars,(equals rT$trm$trm')))))
              | _ => NONE)
          end
        | _ => NONE))