src/Pure/Tools/codegen_theorems.ML
changeset 20191 b43fd26e1aaa
parent 20175 0a8ca32f6e64
child 20192 956cd30ef3be
--- a/src/Pure/Tools/codegen_theorems.ML	Tue Jul 25 16:43:33 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML	Tue Jul 25 16:43:47 2006 +0200
@@ -28,14 +28,15 @@
   val get_datatypes: theory -> string
     -> (((string * sort) list * (string * typ list) list) * thm list) option;
 
-  (*
   type thmtab;
-  val get_thmtab: (string * typ) list -> theory -> thmtab * theory;
-  val get_cons: thmtab -> string -> string option;
-  val get_dtyp: thmtab -> string -> (string * sort) list * (string * typ list) list;
-  val get_thms: thmtab -> string * typ -> typ * thm list;
-  *)
-  
+  val mk_thmtab: (string * typ) list -> theory -> thmtab * theory;
+  val norm_typ: theory -> string * typ -> string * typ;
+  val get_sortalgebra: thmtab -> Sorts.algebra;
+  val get_dtyp_of_cons: thmtab -> string * typ -> string option;
+  val get_dtyp_spec: thmtab -> string
+    -> ((string * sort) list * (string * typ list) list) option;
+  val get_fun_thms: thmtab -> string * typ -> thm list;
+
   val print_thms: theory -> unit;
 
   val init_obj: (thm * thm) * (thm * thm) -> theory -> theory;
@@ -87,7 +88,7 @@
 fun init_obj ((TrueI, FalseE), (conjI, atomize_eq)) thy =
   case CodegenTheoremsSetup.get thy
    of SOME _ => error "code generator already set up for object logic"
-    | NONE => 
+    | NONE =>
         let
           fun strip_implies t = (Logic.strip_imp_prems t, Logic.strip_imp_concl t);
           fun dest_TrueI thm =
@@ -120,7 +121,7 @@
                  #> apfst Term.dest_Const
                )
             |> (fn (v1, ((conj, _), v2)) => if v1 = v2 then conj else error "wrong premise")
-          fun dest_atomize_eq thm =
+          fun dest_atomize_eq thm=
             Drule.plain_prop_of thm
             |> Logic.dest_equals
             |> apfst (
@@ -238,7 +239,7 @@
       if v = v' orelse member (op =) set v then s
         else let
           val t = if i = ~1 then Free (v, ty) else Var (v_i, ty)
-        in 
+        in
           (maxidx + 1,  v :: set,
             (cterm_of thy t, cterm_of thy (Var ((v', maxidx), ty))) :: acc)
         end;
@@ -260,7 +261,7 @@
           drop (eq::eqs) (filter_out (matches eq) eqs')
   in drop [] eqs end;
 
-fun make_eq thy = 
+fun make_eq thy =
   let
     val ((_, atomize), _) = get_obj thy;
   in rewrite_rule [atomize] end;
@@ -368,7 +369,7 @@
 };
 
 fun mk_T ((dirty, notify), (preproc, (extrs, funthms))) =
-  T { dirty = dirty, notify = notify, preproc = preproc, extrs = extrs, funthms = funthms };
+  T { dirty = dirty, notify = notify, preproc= preproc, extrs = extrs, funthms = funthms };
 fun map_T f (T { dirty, notify, preproc, extrs, funthms }) =
   mk_T (f ((dirty, notify), (preproc, (extrs, funthms))));
 fun merge_T pp (T { dirty = dirty1, notify = notify1, preproc = preproc1, extrs = extrs1, funthms = funthms1 },
@@ -405,7 +406,7 @@
         Pretty.str "code generation theorems:",
         Pretty.str "function theorems:" ] @
         (*Pretty.fbreaks ( *)
-          map (fn (c, thms) => 
+          map (fn (c, thms) =>
             (Pretty.block o Pretty.fbreaks) (
               Pretty.str c :: map pretty_thm (rev thms)
             )
@@ -522,7 +523,7 @@
          (preprocs, thm :: unfolds)), y)))
   |> notify_all NONE;
 
-fun del_unfold thm = 
+fun del_unfold thm =
   map_data (fn (x, (preproc, y)) =>
        (x, (preproc |> map_preproc (fn (preprocs, unfolds) =>
          (preprocs, remove eq_thm thm unfolds)), y)))
@@ -546,6 +547,14 @@
 fun extr_typ thy thm = case dest_fun thy thm
  of (_, (ty, _)) => ty;
 
+fun rewrite_rhs conv thm = (case (Drule.strip_comb o cprop_of) thm
+ of (ct', [ct1, ct2]) => (case term_of ct'
+     of Const ("==", _) =>
+          Thm.equal_elim (combination (combination (reflexive ct') (reflexive ct1))
+            (conv ct2)) thm
+      | _ => raise ERROR "rewrite_rhs")
+  | _ => raise ERROR "rewrite_rhs");
+
 fun common_typ thy _ [] = []
   | common_typ thy _ [thm] = [thm]
   | common_typ thy extract_typ thms =
@@ -566,20 +575,13 @@
 fun preprocess thy thms =
   let
     fun burrow_thms f [] = []
-      | burrow_thms f thms = 
+      | burrow_thms f thms =
           thms
           |> Conjunction.intr_list
           |> f
           |> Conjunction.elim_list;
     fun cmp_thms (thm1, thm2) =
       not (Sign.typ_instance thy (extr_typ thy thm1, extr_typ thy thm2));
-    fun rewrite_rhs conv thm = (case (Drule.strip_comb o cprop_of) thm
-     of (ct', [ct1, ct2]) => (case term_of ct'
-         of Const ("==", _) =>
-              Thm.equal_elim (combination (combination (reflexive ct') (reflexive ct1))
-                (conv ct2)) thm
-          | _ => raise ERROR "rewrite_rhs")
-      | _ => raise ERROR "rewrite_rhs");
     fun unvarify thms =
       #1 (Variable.import true thms (ProofContext.init thy));
     val unfold_thms = Tactic.rewrite true (map (make_eq thy) (the_unfolds thy));
@@ -672,7 +674,7 @@
           val (_, lhs) = mk_lhs vs args;
         in (inj, mk_func thy (lhs, fals) :: dist) end;
     fun mk_eqs (vs, cos) =
-      let val cos' = rev cos 
+      let val cos' = rev cos
       in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end;
     fun mk_eq_thms tac vs_cos =
       map (fn t => Goal.prove_global thy [] []
@@ -693,42 +695,150 @@
     | _ => []
   else [];
 
-type thmtab = ((thm list Typtab.table Symtab.table
-  * string Symtab.table)
-  * ((string * sort) list * (string * typ list) list) Symtab.table);
+structure ConstTab = TableFun(type key = string * typ val ord = prod_ord fast_string_ord Term.typ_ord);
+structure ConstGraph = GraphFun(type key = string * typ val ord = prod_ord fast_string_ord Term.typ_ord);
+
+type thmtab = (theory * (thm list ConstGraph.T
+  * string ConstTab.table)
+  * (Sorts.algebra * ((string * sort) list * (string * typ list) list) Symtab.table));
+
+fun thmtab_empty thy = (thy, (ConstGraph.empty, ConstTab.empty),
+  (ClassPackage.operational_algebra thy, Symtab.empty));
+
+fun norm_typ thy (c, ty) =
+  (*more clever: use ty_insts*)
+  case get_first (fn ty' => if Sign.typ_instance thy (ty, (*Logic.varifyT*) ty')
+    then SOME ty' else NONE) (map #lhs (Theory.definitions_of thy c))
+   of NONE => (c, ty)
+    | SOME ty => (c, ty);
+
+fun get_sortalgebra (_, _, (algebra, _)) =
+  algebra;
 
-(*
-fun mk_thmtab thy cs =
+fun get_dtyp_of_cons (thy, (_, dtcotab), _) (c, ty) =
+  let
+    val (_, ty') = norm_typ thy (c, ty);
+  in ConstTab.lookup dtcotab (c, ty') end;
+
+fun get_dtyp_spec (_, _, (_, dttab)) tyco =
+  Symtab.lookup dttab tyco;
+
+fun has_fun_thms (thy, (fungr, _), _) (c, ty) =
+  let
+    val (_, ty') = norm_typ thy (c, ty);
+  in can (ConstGraph.get_node fungr) (c, ty') end;
+
+fun get_fun_thms (thy, (fungr, _), _) (c, ty) =
+  let
+    val (_, ty') = norm_typ thy (c, ty);
+  in these (try (ConstGraph.get_node fungr) (c, ty')) end;
+
+fun mk_thmtab' thy cs =
   let
-    fun add_c (c, ty) gr =
-    (*
-      Das ist noch viel komplizierter: Zyklen
-      und die aktuellen Instantiierungen muss man auch noch mitschleppen
-      man sieht: man braucht zusätzlich ein Mapping
-        c ~> [ty] (Symtab)
-      wobei dort immer die bislang allgemeinsten... ???
-    *)
-    (*
-      thm holen für bestimmten typ
-      typ dann behalten
-      typ normalisieren
-      damit haben wir den key
-      hier den check machen, ob schon prozessiert wurde
-      NEIN:
-        ablegen
-        consts der rechten Seiten
-        in die Rekursion gehen für alles
-      JA:
-        fertig
-    *)
-  in fold add_c cs Constgraph.empty end;
+    fun get_dtco_candidate ty =
+      case strip_type ty
+       of (_, Type (tyco, _)) => SOME tyco
+        | _ => NONE;
+    fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
+      | add_tycos _ = I;
+    val add_consts = fold_aterms
+      (fn Const c_ty => insert (op =) (norm_typ thy c_ty)
+         | _ => I);
+    fun add_dtyps_of_type ty thmtab =
+      let
+        val tycos = add_tycos ty [];
+        val tycos_new = filter (is_some o get_dtyp_spec thmtab) tycos;
+        fun add_dtyp_spec dtco (dtyp_spec as (vs, cs)) ((thy, (fungr, dtcotab), (algebra, dttab))) =
+          let
+            fun mk_co (c, tys) =
+              (c, Logic.varifyT (tys ---> Type (dtco, map TFree vs)));
+          in
+            (thy, (fungr, dtcotab |> fold (fn c_tys => ConstTab.update_new (mk_co c_tys, dtco)) cs),
+              (algebra, dttab |> Symtab.update_new (dtco, dtyp_spec)))
+          end;
+      in
+        thmtab
+        |> fold (fn tyco => case get_datatypes thy tyco
+             of SOME (dtyp_spec, _) => add_dtyp_spec tyco dtyp_spec
+              | NONE => I) tycos_new
+      end;
+    fun known thmtab (c, ty) =
+      is_some (get_dtyp_of_cons thmtab (c, ty)) orelse has_fun_thms thmtab (c, ty);
+    fun add_funthms (c, ty) (thmtab as (thy, (fungr, dtcotab), algebra_dttab))=
+      if known thmtab (norm_typ thy (c, ty)) then thmtab
+      else let
+        val thms = get_funs thy (c, ty)
+        val cs_dep = fold (add_consts o Thm.prop_of) thms [];
+      in
+        (thy, (fungr |> ConstGraph.new_node ((c, ty), thms)
+        , dtcotab), algebra_dttab)
+        |> fold add_c cs_dep
+      end
+    and add_c (c, ty) thmtab =
+      let
+        val (_, ty') = norm_typ thy (c, ty);
+      in
+        thmtab
+        |> add_dtyps_of_type ty
+        |> add_funthms (c, ty)
+      end;
+    fun narrow_typs fungr =
+      let
+        (*
+        (*!!!remember whether any instantiation had been applied!!!*)
+        fun narrow_thms insts thms =
+          let
+            val (c_def, ty_def) =
+              (norm_typ thy o dest_Const o fst o Logic.dest_equals o Thm.prop_of o hd) thms;
+            val cs = fold (add_consts o snd o Logic.dest_equals o Thm.prop_of) thms [];
+            fun eval_inst c (inst, ty) =
+              let
+                val inst_ctxt = Sign.const_typargs thy (c, ty);
+                val inst_zip = fold (fn (v, ty) => (ty, (the o AList.lookup (op =) inst_ctxt) v)) inst
+                fun add_inst (ty_inst, ty_ctxt) =
+                  if Sign.typ_instance thy (ty_inst, ty_ctxt)
+                  then I
+                  else Sign.typ_match thy (ty_ctxt, ty_inst);
+              in fold add_inst inst_zip end;
+            val inst =
+              Vartab.empty
+              |> fold (fn c_ty as (c, ty) =>
+                    case ConstTab.lookup insts (norm_typ thy c_ty)
+                     of NONE => I
+                      | SOME inst => eval_inst c (inst, ty)) cs
+              |> Vartab.dest
+              |> map (fn (v, (_, ty)) => (v, ty));
+            val instT = map (fn (v, ty) =>
+                (Thm.ctyp_of thy (TVar v, Thm.ctyp_of thy ty))) inst;
+            val thms' =
+              if null inst then NONE thms else
+                map Thm.instantiate (instT, []) thms;
+            val inst' = if null inst then NONE
+              else SOME inst;
+          in (inst', thms') end;
+        fun narrow_css [c] (insts, fungr) =
+              (* HIER GEHTS WEITER *)
+              (insts, fungr)
+          | narrow_css css (insts, fungr) =
+              (insts, fungr)
+        *)
+        val css = rev (Graph.strong_conn fungr);
+      in
+        (ConstTab.empty, fungr)
+        (*|> fold narrow_css css*)
+        |> snd
+      end;
+  in
+    thmtab_empty thy
+    |> fold add_c cs
+(*     |> (apfst o apfst) narrow_typs  *)
+  end;
 
-fun get_thmtab cs thy =
+fun mk_thmtab cs thy =
   thy
   |> get_reset_dirty
-  |-> (fn _ => I)
-  |> `mk_thmtab;
-*)
+  |> snd
+  |> `(fn thy => mk_thmtab' thy cs);
 
 
 (** code attributes and setup **)