improved thmtab
authorhaftmann
Thu, 17 Aug 2006 09:24:56 +0200
changeset 20394 21227c43ba26
parent 20393 df3252bbc0e6
child 20395 9a60e3151244
improved thmtab
src/Pure/Tools/codegen_theorems.ML
--- a/src/Pure/Tools/codegen_theorems.ML	Thu Aug 17 09:24:51 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML	Thu Aug 17 09:24:56 2006 +0200
@@ -24,6 +24,9 @@
   val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
   val preprocess: theory -> thm list -> thm list;
 
+  val prove_freeness: theory -> tactic -> string
+    -> (string * sort) list * (string * typ list) list -> thm list;
+
   type thmtab;
   val mk_thmtab: theory -> (string * typ) list -> thmtab;
   val get_sortalgebra: thmtab -> Sorts.algebra;
@@ -32,6 +35,7 @@
     -> ((string * sort) list * (string * typ list) list) option;
   val get_fun_thms: thmtab -> string * typ -> thm list;
 
+  val pretty_funtab: theory -> thm list CodegenConsts.Consttab.table -> Pretty.T;
   val print_thms: theory -> unit;
 
   val init_obj: (thm * thm) * (thm * thm) -> theory -> theory;
@@ -81,7 +85,7 @@
 
 fun init_obj ((TrueI, FalseE), (conjI, atomize_eq)) thy =
   case CodegenTheoremsSetup.get thy
-   of SOME _ => error "code generator already set up for object logic"
+   of SOME _ => error "Code generator already set up for object logic"
     | NONE =>
         let
           fun strip_implies t = (Logic.strip_imp_prems t, Logic.strip_imp_concl t);
@@ -114,7 +118,7 @@
                  #> apsnd (map Term.dest_Var)
                  #> apfst Term.dest_Const
                )
-            |> (fn (v1, ((conj, _), v2)) => if v1 = v2 then conj else error "wrong premise")
+            |> (fn (v1, ((conj, _), v2)) => if v1 = v2 then conj else error "Wrong premise")
           fun dest_atomize_eq thm=
             Drule.plain_prop_of thm
             |> Logic.dest_equals
@@ -130,10 +134,10 @@
                  #> apsnd Term.dest_Var
                )
             |> (fn (((eq, _), v2), (v1a as (_, TVar (_, sort)), v1b)) =>
-                 if [v1a, v1b] = v2 andalso sort = Sign.defaultS thy then eq else error "wrong premise")
+                 if [v1a, v1b] = v2 andalso sort = Sign.defaultS thy then eq else error "Wrong premise")
         in
           ((dest_TrueI TrueI, [dest_FalseE FalseE, dest_conjI conjI, dest_atomize_eq atomize_eq])
-          handle _ => error "bad code generator setup")
+          handle _ => error "Bad code generator setup")
           |> (fn ((tr, b), [fl, con, eq]) => CodegenTheoremsSetup.put
                (SOME ((b, atomize_eq), ((tr, fl), (con, eq)))) thy)
         end;
@@ -141,7 +145,7 @@
 fun get_obj thy =
   case CodegenTheoremsSetup.get thy
    of SOME ((b, atomize), x) => ((Type (b, []), atomize) ,x)
-    | NONE => error "no object logic setup for code theorems";
+    | NONE => error "No object logic setup for code theorems";
 
 fun mk_true thy =
   let
@@ -260,14 +264,14 @@
   case try (make_eq thy #> Drule.plain_prop_of
    #> ObjectLogic.drop_judgment thy #> Logic.dest_equals) thm
    of SOME eq => (eq, thm)
-    | NONE => err_thm "not an equation" thm;
+    | NONE => err_thm "Not an equation" thm;
 
 fun dest_fun thy thm =
   let
     fun dest_fun' ((lhs, _), thm) =
       case try (dest_Const o fst o strip_comb) lhs
        of SOME (c, ty) => (c, (ty, thm))
-        | NONE => err_thm "not a function equation" thm;
+        | NONE => err_thm "Not a function equation" thm;
   in
     thm
     |> dest_eq thy
@@ -280,21 +284,23 @@
 
 (* data structures *)
 
+structure Consttab = CodegenConsts.Consttab;
+
 fun merge' eq (xys as (xs, ys)) =
   if eq_list eq (xs, ys) then (false, xs) else (true, merge eq xys);
 
 fun alist_merge' eq_key eq (xys as (xs, ys)) =
   if eq_list (eq_pair eq_key eq) (xs, ys) then (false, xs) else (true, AList.merge eq_key eq xys);
 
-fun list_symtab_join' eq (xyt as (xt, yt)) =
+fun list_consttab_join' eq (xyt as (xt, yt)) =
   let
-    val xc = Symtab.keys xt;
-    val yc = Symtab.keys yt;
-    val zc = filter (member (op =) yc) xc;
+    val xc = Consttab.keys xt;
+    val yc = Consttab.keys yt;
+    val zc = filter (member CodegenConsts.eq_const yc) xc;
     val wc = subtract (op =) zc xc @ subtract (op =) zc yc;
-    fun same_thms c = if eq_list eq_thm ((the o Symtab.lookup xt) c, (the o Symtab.lookup yt) c)
+    fun same_thms c = if eq_list eq_thm ((the o Consttab.lookup xt) c, (the o Consttab.lookup yt) c)
       then NONE else SOME c;
-  in (wc @ map_filter same_thms zc, Symtab.join (K (merge eq)) xyt) end;
+  in (wc @ map_filter same_thms zc, Consttab.join (K (merge eq)) xyt) end;
 
 datatype notify = Notify of (serial * ((string * typ) list option -> theory -> theory)) list;
 
@@ -337,7 +343,7 @@
 
 datatype funthms = Funthms of {
   dirty: string list,
-  funs: thm list Symtab.table
+  funs: thm list Consttab.table
 };
 
 fun mk_funthms (dirty, funs) =
@@ -347,8 +353,8 @@
 fun merge_funthms _ (Funthms { dirty = dirty1, funs = funs1 },
   Funthms { dirty = dirty2, funs = funs2 }) =
     let
-      val (dirty3, funs) = list_symtab_join' eq_thm (funs1, funs2);
-    in mk_funthms (merge (op =) (merge (op =) (dirty1, dirty2), dirty3), funs) end;
+      val (dirty3, funs) = list_consttab_join' eq_thm (funs1, funs2);
+    in mk_funthms (merge (op =) (merge (op =) (dirty1, dirty2), map fst dirty3), funs) end;
 
 datatype T = T of {
   dirty: bool,
@@ -380,7 +386,7 @@
   val name = "Pure/codegen_theorems_data";
   type T = T;
   val empty = mk_T ((false, mk_notify []), (mk_preproc ([], []),
-    (mk_extrs ([], []), mk_funthms ([], Symtab.empty))));
+    (mk_extrs ([], []), mk_funthms ([], Consttab.empty))));
   val copy = I;
   val extend = I;
   val merge = merge_T;
@@ -388,7 +394,7 @@
     let
       val pretty_thm = ProofContext.pretty_thm (ProofContext.init thy);
       val funthms = (fn T { funthms, ... } => funthms) data;
-      val funs = (Symtab.dest o (fn Funthms { funs, ... } => funs)) funthms;
+      val funs = (Consttab.dest o (fn Funthms { funs, ... } => funs)) funthms;
       val preproc = (fn T { preproc, ... } => preproc) data;
       val unfolds = (fn Preproc { unfolds, ... } => unfolds) preproc;
     in
@@ -398,7 +404,7 @@
         (*Pretty.fbreaks ( *)
           map (fn (c, thms) =>
             (Pretty.block o Pretty.fbreaks) (
-              Pretty.str c :: map pretty_thm (rev thms)
+              (Pretty.str o CodegenConsts.string_of_const thy) c  :: map pretty_thm (rev thms)
             )
           ) funs
         (*) *) @ [
@@ -437,7 +443,9 @@
 (* notifiers *)
 
 fun all_typs thy c =
-  map (pair c) (Sign.the_const_type thy c :: (map (#lhs) o Theory.definitions_of thy) c);
+  let
+    val c_tys = (map (pair c o #lhs o snd) o Defs.specifications_of (Theory.defs_of thy)) c;
+  in (c, Sign.the_const_type thy c) :: map (CodegenConsts.typ_of_typinst thy) c_tys end;
 
 fun add_notify f =
   map_data (fn ((dirty, notify), x) =>
@@ -489,20 +497,20 @@
 
 fun add_fun thm thy =
   case dest_fun thy thm
-   of (c, _) =>
+   of (c, (ty, _)) =>
     thy
     |> map_data (fn (x, (preproc, (extrs, funthms))) =>
         (x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) =>
-          (dirty, funs |> Symtab.default (c, []) |> Symtab.map_entry c (cons thm)))))))
+          (dirty, funs |> Consttab.map_default (CodegenConsts.norminst_of_typ thy (c, ty), []) (cons thm)))))))
     |> notify_all (SOME c);
 
 fun del_fun thm thy =
   case dest_fun thy thm
-   of (c, _) =>
+   of (c, (ty, _)) =>
     thy
     |> map_data (fn (x, (preproc, (extrs, funthms))) =>
         (x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) =>
-          (dirty, funs |> Symtab.map_entry c (remove eq_thm thm)))))))
+          (dirty, funs |> Consttab.map_entry (CodegenConsts.norminst_of_typ thy (c, ty)) (remove eq_thm thm)))))))
     |> notify_all (SOME c);
 
 fun add_unfold thm thy =
@@ -523,9 +531,7 @@
   thy
   |> map_data (fn (x, (preproc, (extrs, funthms))) =>
       (x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) =>
-        (dirty, funs |> Symtab.map_entry c
-            (filter (fn thm => Sign.typ_instance thy
-              ((fst o snd o dest_fun thy) thm, ty)))))))))
+        (dirty, funs |> Consttab.update (CodegenConsts.norminst_of_typ thy (c, ty), [])))))))
   |> notify_all (SOME c);
 
 
@@ -556,7 +562,12 @@
           in (thm', max') end;
         val (thms', maxidx) = fold_map incr_thm thms 0;
         val (ty1::tys) = map extract_typ thms;
-        fun unify ty = Sign.typ_unify thy (ty1, ty);
+        fun unify ty env = Sign.typ_unify thy (ty1, ty) env
+          handle Type.TUNIFY =>
+            error ("Type unificaton failed, while unifying function equations\n"
+            ^ (cat_lines o map Display.string_of_thm) thms
+            ^ "\nwith types\n"
+            ^ (cat_lines o map (Sign.string_of_typ thy)) (ty1 :: tys));
         val (env, _) = fold unify tys (Vartab.empty, maxidx)
         val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
           cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
@@ -611,41 +622,47 @@
 fun get_funs thy (c, ty) =
   let
     val _ = debug_msg (fn _ => "[cg_thm] const (1) " ^ c ^ " :: " ^ Sign.string_of_typ thy ty) ()
-    val filter_typ = map_filter (fn (_, (ty', thm)) =>
-      if Sign.typ_instance thy (ty, ty')
-      then SOME thm else debug_msg (fn _ => "[cg_thm] dropping " ^ string_of_thm thm) NONE);
+    val postprocess_typ = case AxClass.class_of_param thy c
+     of NONE => map_filter (fn (_, (ty', thm)) =>
+          if Sign.typ_instance thy (ty, ty')
+          then SOME thm else debug_msg (fn _ => "[cg_thm] dropping " ^ string_of_thm thm) NONE)
+      | SOME _ => let
+          (*FIXME make this more elegant*)
+          val ty' = CodegenConsts.typ_of_classop thy (CodegenConsts.norminst_of_typ thy (c, ty));
+          val ct = Thm.cterm_of thy (Const (c, ty'));
+          val thm' = Thm.reflexive ct;
+        in map (snd o snd) #> cons thm' #> common_typ thy (extr_typ thy) #> tl end;
     fun get_funs (c, ty) =
-      (these o Symtab.lookup (the_funs thy)) c
+      (these o Consttab.lookup (the_funs thy) o CodegenConsts.norminst_of_typ thy) (c, ty)
       |> debug_msg (fn _ => "[cg_thm] trying funs")
       |> map (dest_fun thy)
-      |> filter_typ;
+      |> postprocess_typ;
     fun get_extr (c, ty) =
       getf_first_list (map (fn f => f thy) (the_funs_extrs thy)) (c, ty)
       |> debug_msg (fn _ => "[cg_thm] trying extr")
       |> map (dest_fun thy)
-      |> filter_typ;
+      |> postprocess_typ;
     fun get_spec (c, ty) =
-      Theory.definitions_of thy c
+      (CodegenConsts.find_def thy o CodegenConsts.norminst_of_typ thy) (c, ty)
       |> debug_msg (fn _ => "[cg_thm] trying spec")
-      (* FIXME avoid dynamic name space lookup!? (via Thm.get_axiom_i etc.??) *)
-      |> maps (fn { name, ... } => these (try (PureThy.get_thms thy) (Name name)))
+      |> Option.mapPartial (fn ((_, name), _) => try (Thm.get_axiom_i thy) name)
+      |> the_list
       |> map_filter (try (dest_fun thy))
-      |> filter_typ;
+      |> postprocess_typ;
   in
     getf_first_list [get_funs, get_extr, get_spec] (c, ty)
     |> debug_msg (fn _ => "[cg_thm] const (2) " ^ c ^ " :: " ^ Sign.string_of_typ thy ty)
     |> preprocess thy
   end;
 
-fun get_datatypes thy dtco =
+fun prove_freeness thy tac dtco vs_cos =
   let
-    val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) ()
     val truh = mk_true thy;
     val fals = mk_false thy;
     fun mk_lhs vs ((co1, tys1), (co2, tys2)) =
       let
         val dty = Type (dtco, map TFree vs);
-        val (xs1, xs2) = chop (length tys1) (Name.invent_list [] "x" (length tys1 + length tys2));
+        val (xs1, xs2) = chop (length tys1) (Name.invent_list [] "a" (length tys1 + length tys2));
         val frees1 = map2 (fn x => fn ty => Free (x, ty)) xs1 tys1;
         val frees2 = map2 (fn x => fn ty => Free (x, ty)) xs2 tys2;
         fun zip_co co xs tys = list_comb (Const (co,
@@ -667,13 +684,18 @@
     fun mk_eqs (vs, cos) =
       let val cos' = rev cos
       in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end;
-    fun mk_eq_thms tac vs_cos =
-      map (fn t => Goal.prove_global thy [] []
-        (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos);
+  in
+    map (fn t => Goal.prove_global thy [] []
+        (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos)
+  end;
+
+fun get_datatypes thy dtco =
+  let
+    val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) ()
   in
     case getf_first (map (fn f => f thy) (the_datatypes_extrs thy)) dtco
      of NONE => NONE
-      | SOME (vs_cos, tac) => SOME (vs_cos, mk_eq_thms tac vs_cos)
+      | SOME (vs_cos, tac) => SOME (vs_cos, prove_freeness thy tac dtco vs_cos)
   end;
 
 fun get_eq thy (c, ty) =
@@ -691,13 +713,13 @@
     fun check_head_lhs thm (lhs, rhs) =
       case strip_comb lhs
        of (Const (c', _), _) => if c' = c then ()
-           else error ("illegal function equation for " ^ quote c
+           else error ("Illegal function equation for " ^ quote c
              ^ ", actually defining " ^ quote c' ^ ": " ^ Display.string_of_thm thm)
-        | _ => error ("illegal function equation: " ^ Display.string_of_thm thm);
+        | _ => error ("Illegal function equation: " ^ Display.string_of_thm thm);
     fun check_vars_lhs thm (lhs, rhs) =
       if has_duplicates (op =)
           (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
-      then error ("repeated variables on left hand side of function equation:"
+      then error ("Repeated variables on left hand side of function equation:"
         ^ Display.string_of_thm thm)
       else ();
     fun check_vars_rhs thm (lhs, rhs) =
@@ -705,7 +727,7 @@
         (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
         (fold_aterms (fn Free (v, _) => cons v | _ => I) rhs []))
       then ()
-      else error ("free variables on right hand side of function equation:"
+      else error ("Free variables on right hand side of function equation:"
         ^ Display.string_of_thm thm)
     val tts = map (Logic.dest_equals o Logic.unvarify o Thm.prop_of) thms;
   in