src/Pure/Tools/codegen_theorems.ML
changeset 20386 d1cbe5aa6bf2
parent 20353 d73e49780ef2
child 20394 21227c43ba26
--- a/src/Pure/Tools/codegen_theorems.ML	Mon Aug 14 13:46:20 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML	Mon Aug 14 13:46:21 2006 +0200
@@ -24,12 +24,8 @@
   val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
   val preprocess: theory -> thm list -> thm list;
 
-  val get_datatypes: theory -> string
-    -> (((string * sort) list * (string * typ list) list) * thm list) option;
-
   type thmtab;
   val mk_thmtab: theory -> (string * typ) list -> thmtab;
-  val norm_typ: theory -> string * typ -> string * typ;
   val get_sortalgebra: thmtab -> Sorts.algebra;
   val get_dtyp_of_cons: thmtab -> string * typ -> string option;
   val get_dtyp_spec: thmtab -> string
@@ -41,8 +37,6 @@
   val init_obj: (thm * thm) * (thm * thm) -> theory -> theory;
   val debug: bool ref;
   val debug_msg: ('a -> string) -> 'a -> 'a;
-  structure ConstTab: TABLE;
-  structure ConstGraph: GRAPH;
 end;
 
 structure CodegenTheorems: CODEGEN_THEOREMS =
@@ -214,8 +208,6 @@
     |> beta_norm
   end;
 
-val lower_name = translate_string Symbol.to_ascii_lower o Symbol.alphanum;
-
 fun canonical_tvars thy thm =
   let
     fun mk_inst (v_i as (v, i), (v', sort)) (s as (maxidx, set, acc)) =
@@ -227,7 +219,7 @@
             (ctyp_of thy ty, ctyp_of thy (TVar ((v', maxidx), sort))) :: acc)
         end;
     fun tvars_of thm = (fold_types o fold_atyps)
-      (fn TVar (v_i as (v, i), sort) => cons (v_i, (lower_name v, sort))
+      (fn TVar (v_i as (v, i), sort) => cons (v_i, (CodegenNames.purify_var v, sort))
         | _ => I) (prop_of thm) [];
     val maxidx = Thm.maxidx_of thm + 1;
     val (_, _, inst) = fold mk_inst (tvars_of thm) (maxidx + 1, [], []);
@@ -244,7 +236,7 @@
             (cterm_of thy t, cterm_of thy (Var ((v', maxidx), ty))) :: acc)
         end;
     fun vars_of thm = fold_aterms
-      (fn Var (v_i as (v, i), ty) => cons (v_i, (lower_name v, ty))
+      (fn Var (v_i as (v, i), ty) => cons (v_i, (CodegenNames.purify_var v, ty))
         | _ => I) (prop_of thm) [];
     val maxidx = Thm.maxidx_of thm + 1;
     val (_, _, inst) = fold mk_inst (vars_of thm) (maxidx + 1, [], []);
@@ -721,61 +713,128 @@
       map2 check_vars_rhs thms tts; thms)
   end;
 
-structure ConstTab = TableFun(type key = string * typ val ord = prod_ord fast_string_ord Term.typ_ord);
-structure ConstGraph = GraphFun(type key = string * typ val ord = prod_ord fast_string_ord Term.typ_ord);
-
-type thmtab = (theory * (thm list ConstGraph.T
-  * string ConstTab.table)
+structure Consttab = CodegenConsts.Consttab;
+type thmtab = (theory * (thm list Consttab.table
+  * string Consttab.table)
   * (Sorts.algebra * ((string * sort) list * (string * typ list) list) Symtab.table));
 
-fun thmtab_empty thy = (thy, (ConstGraph.empty, ConstTab.empty),
+fun thmtab_empty thy = (thy, (Consttab.empty, Consttab.empty),
   (ClassPackage.operational_algebra thy, Symtab.empty));
 
-fun norm_typ thy (c, ty) =
-  (*more clever: use ty_insts*)
-  case get_first (fn ty' => if Sign.typ_instance thy (ty, (*Logic.varifyT*) ty')
-    then SOME ty' else NONE) (map #lhs (Theory.definitions_of thy c))
-   of NONE => (c, ty)
-    | SOME ty => (c, ty);
-
 fun get_sortalgebra (_, _, (algebra, _)) =
   algebra;
 
-fun get_dtyp_of_cons (thy, (_, dtcotab), _) (c, ty) =
-  let
-    val (_, ty') = norm_typ thy (c, ty);
-  in ConstTab.lookup dtcotab (c, ty') end;
+fun get_dtyp_of_cons (thy, (_, dtcotab), _) =
+  Consttab.lookup dtcotab o CodegenConsts.norminst_of_typ thy;
+
+fun get_dtyp_spec (_, _, (_, dttab)) =
+  Symtab.lookup dttab;
+
+fun has_fun_thms (thy, (funtab, _), _) =
+  is_some o Consttab.lookup funtab o CodegenConsts.norminst_of_typ thy;
 
-fun get_dtyp_spec (_, _, (_, dttab)) tyco =
-  Symtab.lookup dttab tyco;
+fun get_fun_thms (thy, (funtab, _), _) (c_ty as (c, _)) =
+  (check_thms c o these o Consttab.lookup funtab
+    o CodegenConsts.norminst_of_typ thy) c_ty;
+
+fun pretty_funtab thy funtab =
+  funtab
+  |> CodegenConsts.Consttab.dest
+  |> map (fn (c, thms) =>
+       (Pretty.block o Pretty.fbreaks) (
+         (Pretty.str o CodegenConsts.string_of_const thy) c
+         :: map Display.pretty_thm thms
+       ))
+  |> Pretty.chunks;
 
-fun has_fun_thms (thy, (fungr, _), _) (c, ty) =
+fun constrain_funtab thy funtab =
   let
-    val (_, ty') = norm_typ thy (c, ty);
-  in can (ConstGraph.get_node fungr) (c, ty') end;
-
-fun get_fun_thms (thy, (fungr, _), _) (c, ty) =
-  let
-    val (_, ty') = norm_typ thy (c, ty);
-  in these (try (ConstGraph.get_node fungr) (c, ty')) |> check_thms c end;
+    fun max k [] = k
+      | max k (l::ls) = max (if k < l then l else k) ls;
+    fun mk_consttyps funtab =
+      CodegenConsts.Consttab.empty
+      |> CodegenConsts.Consttab.fold (fn (c, thm :: _) =>
+           CodegenConsts.Consttab.update_new (c, extr_typ thy thm) | (_, []) => I) funtab
+    fun mk_typescheme_of typtab (c, ty) =
+      CodegenConsts.Consttab.lookup typtab (CodegenConsts.norminst_of_typ thy (c, ty));
+    fun incr_indices (c, thms) maxidx =
+      let
+        val thms' = map (Thm.incr_indexes maxidx) thms;
+        val maxidx' = Int.max
+          (maxidx, max ~1 (map Thm.maxidx_of thms') + 1);
+      in (thms', maxidx') end;
+    fun consts_of_eqs thms =
+      let
+        fun terms_of_eq thm =
+          let
+            val (lhs, rhs) = (Logic.dest_equals o Drule.plain_prop_of) thm
+          in rhs :: (snd o strip_comb) lhs end;
+      in (fold o fold_aterms) (fn Const c => insert (eq_pair (op =) (Type.eq_type Vartab.empty)) c | _ => I)
+        (maps terms_of_eq thms) []
+      end;
+    val typscheme_of =
+      mk_typescheme_of (mk_consttyps funtab);
+    val tsig = Sign.tsig_of thy;
+    fun unify_const (c, ty) (env, maxidx) =
+      case typscheme_of (c, ty)
+       of SOME ty_decl => let
+            (*val _ = writeln "UNIFY";
+            val _ = writeln (CodegenConsts.string_of_const_typ thy (c, ty))*)
+            val ty_decl' = Logic.incr_tvar maxidx ty_decl;
+            (*val _ = writeln "WITH";
+            val _ = writeln (CodegenConsts.string_of_const_typ thy (c, ty_decl'))*)
+            val maxidx' = Int.max (Term.maxidx_of_typ ty_decl' + 1, maxidx);
+            (*val _ = writeln ("  " ^ string_of_int maxidx ^ " +> " ^ string_of_int maxidx');*)
+          in Type.unify tsig (ty_decl', ty) (env, maxidx') end
+        | NONE => (env, maxidx);
+    fun apply_unifier unif [] = []
+      | apply_unifier unif (thms as thm :: _) =
+          let
+            val ty = extr_typ thy thm;
+            val ty' = Envir.norm_type unif ty;
+            val env = Type.typ_match (Sign.tsig_of thy) (ty, ty') Vartab.empty;
+            val inst = Thm.instantiate (Vartab.fold (fn (x_i, (sort, ty)) =>
+              cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [], []);
+          in map (Drule.zero_var_indexes o inst) thms end;
+(*     val _ = writeln "(1)";  *)
+(*     val _ = (Pretty.writeln o pretty_funtab thy) funtab;  *)
+    val (funtab', maxidx) =
+      CodegenConsts.Consttab.fold_map incr_indices funtab 0;
+(*     val _ = writeln "(2)";
+ *     val _ = (Pretty.writeln o pretty_funtab thy) funtab';
+ *)
+    val (unif, _) =
+      CodegenConsts.Consttab.fold (fold unify_const o consts_of_eqs o snd)
+        funtab' (Vartab.empty, maxidx);
+(*     val _ = writeln "(3)";  *)
+    val funtab'' =
+      CodegenConsts.Consttab.map (apply_unifier unif) funtab';
+(*     val _ = writeln "(4)";
+ *     val _ = (Pretty.writeln o pretty_funtab thy) funtab'';
+ *)
+  in funtab'' end;
 
 fun mk_thmtab thy cs =
   let
     fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
       | add_tycos _ = I;
-    val add_consts = fold_aterms
-      (fn Const c_ty => insert (op =) (norm_typ thy c_ty)
-         | _ => I);
+    fun consts_of ts =
+      Consttab.empty
+      |> (fold o fold_aterms)
+           (fn Const c_ty => Consttab.update (CodegenConsts.norminst_of_typ thy c_ty, ())
+             | _ => I) ts
+      |> Consttab.keys;
     fun add_dtyps_of_type ty thmtab =
       let
         val tycos = add_tycos ty [];
         val tycos_new = filter (is_none o get_dtyp_spec thmtab) tycos;
-        fun add_dtyp_spec dtco (dtyp_spec as (vs, cs)) ((thy, (fungr, dtcotab), (algebra, dttab))) =
+        fun add_dtyp_spec dtco (dtyp_spec as (vs, cs)) ((thy, (funtab, dtcotab), (algebra, dttab))) =
           let
             fun mk_co (c, tys) =
-              (c, Logic.varifyT (tys ---> Type (dtco, map TFree vs)));
+              CodegenConsts.norminst_of_typ thy (c, Logic.varifyT (tys ---> Type (dtco, map TFree vs)));
           in
-            (thy, (fungr, dtcotab |> fold (fn c_tys => ConstTab.update_new (mk_co c_tys, dtco)) cs),
+            (thy, (funtab, dtcotab |> fold (fn c_tys =>
+              Consttab.update_new (mk_co c_tys, dtco)) cs),
               (algebra, dttab |> Symtab.update_new (dtco, dtyp_spec)))
           end;
       in
@@ -786,27 +845,25 @@
       end;
     fun known thmtab (c, ty) =
       is_some (get_dtyp_of_cons thmtab (c, ty)) orelse has_fun_thms thmtab (c, ty);
-    fun add_funthms (c, ty) (thmtab as (thy, (fungr, dtcotab), algebra_dttab))=
-      if known thmtab (norm_typ thy (c, ty)) then thmtab
+    fun add_funthms (c, ty) (thmtab as (thy, (funtab, dtcotab), algebra_dttab))=
+      if known thmtab (c, ty) then thmtab
       else let
         val thms = get_funs thy (c, ty)
-        val cs_dep = fold (add_consts o Thm.prop_of) thms [];
+        val cs_dep = (consts_of o map Thm.prop_of) thms;
       in
-        (thy, (fungr |> ConstGraph.new_node ((c, ty), thms)
+        (thy, (funtab |> Consttab.update_new (CodegenConsts.norminst_of_typ thy (c, ty), thms)
         , dtcotab), algebra_dttab)
         |> fold add_c cs_dep
       end
-    and add_c (c, ty) thmtab =
-      let
-        val (_, ty') = norm_typ thy (c, ty);
-      in
-        thmtab
-        |> add_dtyps_of_type ty'
-        |> add_funthms (c, ty')
-      end;
+    and add_c (c_tys as (c, tys)) thmtab =
+      thmtab
+      |> add_dtyps_of_type (snd (CodegenConsts.typ_of_typinst thy c_tys))
+      |> fold (add_funthms o CodegenConsts.typ_of_typinst thy)
+           (CodegenConsts.insts_of_classop thy c_tys);
   in
     thmtab_empty thy
-    |> fold add_c cs
+    |> fold (add_c o CodegenConsts.norminst_of_typ thy) cs
+    |> (fn (a, (funtab, b), c) => (a, (funtab |> constrain_funtab thy, b), c))
   end;