src/Pure/Tools/codegen_theorems.ML
changeset 20353 d73e49780ef2
parent 20218 be3bfb0699ba
child 20386 d1cbe5aa6bf2
--- a/src/Pure/Tools/codegen_theorems.ML	Tue Aug 08 08:19:18 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML	Tue Aug 08 08:19:30 2006 +0200
@@ -24,12 +24,11 @@
   val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
   val preprocess: theory -> thm list -> thm list;
 
-  val get_funs: theory -> string * typ -> thm list;
   val get_datatypes: theory -> string
     -> (((string * sort) list * (string * typ list) list) * thm list) option;
 
   type thmtab;
-  val mk_thmtab: (string * typ) list -> theory -> thmtab * theory;
+  val mk_thmtab: theory -> (string * typ) list -> thmtab;
   val norm_typ: theory -> string * typ -> string * typ;
   val get_sortalgebra: thmtab -> Sorts.algebra;
   val get_dtyp_of_cons: thmtab -> string * typ -> string option;
@@ -42,7 +41,8 @@
   val init_obj: (thm * thm) * (thm * thm) -> theory -> theory;
   val debug: bool ref;
   val debug_msg: ('a -> string) -> 'a -> 'a;
-
+  structure ConstTab: TABLE;
+  structure ConstGraph: GRAPH;
 end;
 
 structure CodegenTheorems: CODEGEN_THEOREMS =
@@ -412,7 +412,7 @@
         (*) *) @ [
         Pretty.fbrk,
         Pretty.block (
-          Pretty.str "unfolding theorems:"
+          Pretty.str "inlined theorems:"
           :: Pretty.fbrk
           :: (Pretty.fbreaks o map pretty_thm) unfolds
       )])
@@ -610,6 +610,7 @@
         #> Drule.zero_var_indexes
        )
     |> drop_redundant thy
+    |> debug_msg (fn _ => "[cg_thm] preprocessing done")
   end;
 
 
@@ -635,7 +636,7 @@
       Theory.definitions_of thy c
       |> debug_msg (fn _ => "[cg_thm] trying spec")
       (* FIXME avoid dynamic name space lookup!? (via Thm.get_axiom_i etc.??) *)
-      |> maps (PureThy.get_thms thy o Name o #name)
+      |> maps (fn { name, ... } => these (try (PureThy.get_thms thy) (Name name)))
       |> map_filter (try (dest_fun thy))
       |> filter_typ;
   in
@@ -693,6 +694,33 @@
     | _ => []
   else [];
 
+fun check_thms c thms =
+  let
+    fun check_head_lhs thm (lhs, rhs) =
+      case strip_comb lhs
+       of (Const (c', _), _) => if c' = c then ()
+           else error ("illegal function equation for " ^ quote c
+             ^ ", actually defining " ^ quote c' ^ ": " ^ Display.string_of_thm thm)
+        | _ => error ("illegal function equation: " ^ Display.string_of_thm thm);
+    fun check_vars_lhs thm (lhs, rhs) =
+      if has_duplicates (op =)
+          (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
+      then error ("repeated variables on left hand side of function equation:"
+        ^ Display.string_of_thm thm)
+      else ();
+    fun check_vars_rhs thm (lhs, rhs) =
+      if null (subtract (op =)
+        (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
+        (fold_aterms (fn Free (v, _) => cons v | _ => I) rhs []))
+      then ()
+      else error ("free variables on right hand side of function equation:"
+        ^ Display.string_of_thm thm)
+    val tts = map (Logic.dest_equals o Logic.unvarify o Thm.prop_of) thms;
+  in
+    (map2 check_head_lhs thms tts; map2 check_vars_lhs thms tts;
+      map2 check_vars_rhs thms tts; thms)
+  end;
+
 structure ConstTab = TableFun(type key = string * typ val ord = prod_ord fast_string_ord Term.typ_ord);
 structure ConstGraph = GraphFun(type key = string * typ val ord = prod_ord fast_string_ord Term.typ_ord);
 
@@ -729,14 +757,10 @@
 fun get_fun_thms (thy, (fungr, _), _) (c, ty) =
   let
     val (_, ty') = norm_typ thy (c, ty);
-  in these (try (ConstGraph.get_node fungr) (c, ty')) end;
+  in these (try (ConstGraph.get_node fungr) (c, ty')) |> check_thms c end;
 
-fun mk_thmtab' thy cs =
+fun mk_thmtab thy cs =
   let
-    fun get_dtco_candidate ty =
-      case strip_type ty
-       of (_, Type (tyco, _)) => SOME tyco
-        | _ => NONE;
     fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
       | add_tycos _ = I;
     val add_consts = fold_aterms
@@ -745,7 +769,7 @@
     fun add_dtyps_of_type ty thmtab =
       let
         val tycos = add_tycos ty [];
-        val tycos_new = filter (is_some o get_dtyp_spec thmtab) tycos;
+        val tycos_new = filter (is_none o get_dtyp_spec thmtab) tycos;
         fun add_dtyp_spec dtco (dtyp_spec as (vs, cs)) ((thy, (fungr, dtcotab), (algebra, dttab))) =
           let
             fun mk_co (c, tys) =
@@ -777,66 +801,14 @@
         val (_, ty') = norm_typ thy (c, ty);
       in
         thmtab
-        |> add_dtyps_of_type ty
-        |> add_funthms (c, ty)
-      end;
-    fun narrow_typs fungr =
-      let
-        (*
-        (*!!!remember whether any instantiation had been applied!!!*)
-        fun narrow_thms insts thms =
-          let
-            val (c_def, ty_def) =
-              (norm_typ thy o dest_Const o fst o Logic.dest_equals o Thm.prop_of o hd) thms;
-            val cs = fold (add_consts o snd o Logic.dest_equals o Thm.prop_of) thms [];
-            fun eval_inst c (inst, ty) =
-              let
-                val inst_ctxt = Sign.const_typargs thy (c, ty);
-                val inst_zip = fold (fn (v, ty) => (ty, (the o AList.lookup (op =) inst_ctxt) v)) inst
-                fun add_inst (ty_inst, ty_ctxt) =
-                  if Sign.typ_instance thy (ty_inst, ty_ctxt)
-                  then I
-                  else Sign.typ_match thy (ty_ctxt, ty_inst);
-              in fold add_inst inst_zip end;
-            val inst =
-              Vartab.empty
-              |> fold (fn c_ty as (c, ty) =>
-                    case ConstTab.lookup insts (norm_typ thy c_ty)
-                     of NONE => I
-                      | SOME inst => eval_inst c (inst, ty)) cs
-              |> Vartab.dest
-              |> map (fn (v, (_, ty)) => (v, ty));
-            val instT = map (fn (v, ty) =>
-                (Thm.ctyp_of thy (TVar v, Thm.ctyp_of thy ty))) inst;
-            val thms' =
-              if null inst then NONE thms else
-                map Thm.instantiate (instT, []) thms;
-            val inst' = if null inst then NONE
-              else SOME inst;
-          in (inst', thms') end;
-        fun narrow_css [c] (insts, fungr) =
-              (* HIER GEHTS WEITER *)
-              (insts, fungr)
-          | narrow_css css (insts, fungr) =
-              (insts, fungr)
-        *)
-        val css = rev (Graph.strong_conn fungr);
-      in
-        (ConstTab.empty, fungr)
-        (*|> fold narrow_css css*)
-        |> snd
+        |> add_dtyps_of_type ty'
+        |> add_funthms (c, ty')
       end;
   in
     thmtab_empty thy
     |> fold add_c cs
-(*     |> (apfst o apfst) narrow_typs  *)
   end;
 
-fun mk_thmtab cs thy =
-  thy
-  |> get_reset_dirty
-  |> snd
-  |> `(fn thy => mk_thmtab' thy cs);
 
 
 (** code attributes and setup **)
@@ -848,9 +820,10 @@
 in
   val _ = map (Context.add_setup o add_simple_attribute) [
     ("fun", add_fun),
+    ("nofun", del_fun),
     ("unfold", (fn thm => Codegen.add_unfold thm #> add_unfold thm)),
     ("inline", add_unfold),
-    ("nofold", del_unfold)
+    ("noinline", del_unfold)
   ]
 end; (*local*)