explicit table with constant types
authorhaftmann
Mon, 04 Sep 2006 08:18:00 +0200
changeset 20466 7c20ddbd911b
parent 20465 95f6d354b0ed
child 20467 210b326a03c9
explicit table with constant types
src/Pure/Tools/codegen_package.ML
src/Pure/Tools/codegen_serializer.ML
src/Pure/Tools/codegen_theorems.ML
src/Pure/Tools/codegen_thingol.ML
--- a/src/Pure/Tools/codegen_package.ML	Mon Sep 04 08:17:28 2006 +0200
+++ b/src/Pure/Tools/codegen_package.ML	Mon Sep 04 08:18:00 2006 +0200
@@ -106,7 +106,7 @@
 
 (* theory data  *)
 
-type appgen = theory -> Sorts.algebra * (sort -> sort) -> CodegenTheorems.thmtab
+type appgen = theory -> ((sort -> sort) * Sorts.algebra) * Consts.T -> CodegenTheorems.thmtab
   -> bool * string list option -> (string * typ) * term list -> transact -> iterm * transact;
 
 type appgens = (int * (appgen * stamp)) Symtab.table;
@@ -302,7 +302,7 @@
 
 fun ensure_def_class thy algbr thmtab strct cls trns =
   let
-    fun defgen_class thy (algbr as (_, proj_sort)) thmtab strct cls trns =
+    fun defgen_class thy (algbr as ((proj_sort, _), _)) thmtab strct cls trns =
       case class_of_idf thy cls
        of SOME cls =>
             let
@@ -356,7 +356,7 @@
         ("generating type constructor " ^ quote tyco) tyco'
     |> pair tyco'
   end
-and exprgen_tyvar_sort thy (algbr as (_, proj_sort)) thmtab strct (v, sort) trns =
+and exprgen_tyvar_sort thy (algbr as ((proj_sort, _), _)) thmtab strct (v, sort) trns =
   trns
   |> fold_map (ensure_def_class thy algbr thmtab strct) (proj_sort sort)
   |-> (fn sort => pair (unprefix "'" v, sort))
@@ -377,7 +377,7 @@
       ||>> fold_map (exprgen_type thy algbr thmtab strct) tys
       |-> (fn (tyco, tys) => pair (tyco `%% tys));
 
-fun exprgen_typinst thy (algbr as (algebra, proj_sort)) thmtab strct (ty_ctxt, sort_decl) trns =
+fun exprgen_typinst thy (algbr as ((proj_sort, algebra), consts)) thmtab strct (ty_ctxt, sort_decl) trns =
   let
     val pp = Sign.pp thy;
     datatype inst =
@@ -410,28 +410,19 @@
     trns
     |> fold_map mk_dict insts
   end
-and exprgen_typinst_const thy algbr thmtab strct (c, ty_ctxt) trns =
-  let
-    val ty_decl = case CodegenTheorems.get_fun_thms thmtab (c, ty_ctxt)
-     of thms as thm :: _ => CodegenTheorems.extr_typ thy thm
-      | [] => (case AxClass.class_of_param thy c
-         of SOME class => (case ClassPackage.the_consts_sign thy class of (v, cs) =>
-              (Logic.varifyT o map_type_tfree (fn u as (w, _) =>
-                if w = v then TFree (v, [class]) else TFree u))
-              ((the o AList.lookup (op =) cs) c))
-          | NONE => Sign.the_const_type thy c);
-    val insts =
-      Vartab.empty
-      |> Sign.typ_match thy (ty_decl, ty_ctxt)
-      |> Vartab.dest
-      |> map (fn (_, (sort, ty)) => (ty, sort))
+and exprgen_typinst_const thy (algbr as (_, consts)) thmtab strct (c, ty_ctxt) trns =
+  let 
+    val idf = idf_of_const thy thmtab (c, ty_ctxt)
+    val ty_decl = Consts.declaration consts idf;
+    val insts = (op ~~ o apsnd (map (snd o dest_TVar)) oo pairself)
+      (curry (Consts.typargs consts) idf) (ty_ctxt, ty_decl);
   in
     trns
     |> fold_map (exprgen_typinst thy algbr thmtab strct) insts
   end
 and ensure_def_inst thy algbr thmtab strct (cls, tyco) trns =
   let
-    fun defgen_inst thy (algbr as (_, proj_sort)) thmtab strct inst trns =
+    fun defgen_inst thy (algbr as ((proj_sort, _), _)) thmtab strct inst trns =
       case inst_of_idf thy inst
        of SOME (class, tyco) =>
             let
@@ -442,7 +433,8 @@
               fun gen_suparity supclass trns =
                 trns
                 |> ensure_def_class thy algbr thmtab strct supclass
-                ||>> exprgen_typinst thy algbr thmtab strct (arity_typ, [supclass]);
+                ||>> exprgen_typinst thy algbr thmtab strct (arity_typ, [supclass])
+                |-> (fn (supclass, [Instance (supints, lss)]) => pair (supclass, (supints, lss)));
               fun gen_membr ((m0, ty0), (m, ty)) trns =
                 trns
                 |> ensure_def_const thy algbr thmtab strct (m0, ty0)
@@ -492,18 +484,17 @@
             |-> (fn _ => succeed Bot)
         | _ =>
             trns |> fail ("No class found for " ^ (quote o CodegenConsts.string_of_const_typ thy) (c, ty))
-    fun defgen_funs thy algbr thmtab strct c' trns =
+    fun defgen_funs thy (algbr as (_, consts)) thmtab strct c' trns =
       case CodegenTheorems.get_fun_thms thmtab ((the o const_of_idf thy) c')
        of eq_thms as eq_thm :: _ =>
             let
               val msg = cat_lines ("generating code for theorems " :: map string_of_thm eq_thms);
               val ty = (Logic.unvarifyT o CodegenTheorems.extr_typ thy) eq_thm;
-              val vs = (rev ooo fold_atyps)
-                (fn TFree v_sort => insert (op =) v_sort | _ => I) ty [];
+              val vs = (map dest_TFree o Consts.typargs consts) (c', ty);
               fun dest_eqthm eq_thm =
                 let
                   val ((t, args), rhs) =
-                    (apfst strip_comb o Logic.dest_equals o Logic.legacy_unvarify o prop_of) eq_thm;
+                    (apfst strip_comb o Logic.dest_equals o Logic.unvarify o prop_of) eq_thm;
                 in case t
                  of Const (c', _) => if c' = c then (args, rhs)
                      else error ("Illegal function equation for " ^ quote c
@@ -700,13 +691,35 @@
 (** code generation interfaces **)
 
 fun generate cs targets init gen it thy =
-  thy
-  |> CodegenTheorems.notify_dirty
-  |> `Code.get
-  |> (fn (modl, thy) =>
-        (start_transact init (gen thy (ClassPackage.operational_algebra thy) (CodegenTheorems.mk_thmtab thy cs)
-          (true, targets) it) modl, thy))
-  |-> (fn (x, modl) => Code.map (K modl) #> pair x);
+  let
+    val thmtab = CodegenTheorems.mk_thmtab thy cs;
+    val qnaming = NameSpace.qualified_names NameSpace.default_naming
+    val algebr = ClassPackage.operational_algebra thy;
+    fun ops_of_class class =
+      let
+        val (v, ops) = ClassPackage.the_consts_sign thy class;
+        val ops_tys = map (fn (c, ty) =>
+          (c, (Logic.varifyT o map_type_tfree (fn u as (w, _) =>
+            if w = v then TFree (v, [class]) else TFree u)) ty)) ops;
+      in
+        map (fn (c, ty) => (idf_of_const thy thmtab (c, ty), ty)) ops_tys
+      end;
+    val classops = maps ops_of_class (Sorts.classes (snd algebr));
+    val consttab = Consts.empty
+      |> fold (fn (c, ty) => Consts.declare qnaming
+           (((idf_of_const thy thmtab o CodegenConsts.typ_of_typinst thy) c, ty), true))
+           (CodegenTheorems.get_fun_typs thmtab)
+      |> fold (Consts.declare qnaming o rpair true) classops;
+    val algbr = (algebr, consttab);
+  in   
+    thy
+    |> CodegenTheorems.notify_dirty
+    |> `Code.get
+    |> (fn (modl, thy) =>
+          (start_transact init (gen thy algbr thmtab
+            (true, targets) it) modl, thy))
+    |-> (fn (x, modl) => Code.map (K modl) #> pair x)
+  end;
 
 fun consts_of t =
   fold_aterms (fn Const c => cons c | _ => I) t [];
--- a/src/Pure/Tools/codegen_serializer.ML	Mon Sep 04 08:17:28 2006 +0200
+++ b/src/Pure/Tools/codegen_serializer.ML	Mon Sep 04 08:18:00 2006 +0200
@@ -738,11 +738,11 @@
       | ml_from_def (name, CodegenThingol.Classinst ((class, (tyco, arity)), (suparities, memdefs))) =
           let
             val definer = if null arity then "val" else "fun"
-            fun from_supclass (supclass, ls) =
+            fun from_supclass (supclass, (supinst, lss)) =
               (Pretty.block o Pretty.breaks) [
                 ml_from_label supclass,
                 str "=",
-                ml_from_insts NOBR ls
+                ml_from_insts NOBR [Instance (supinst, lss)]
               ];
             fun from_memdef (m, e) =
               (Pretty.block o Pretty.breaks) [
--- a/src/Pure/Tools/codegen_theorems.ML	Mon Sep 04 08:17:28 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML	Mon Sep 04 08:18:00 2006 +0200
@@ -33,6 +33,7 @@
   val get_dtyp_of_cons: thmtab -> string * typ -> string option;
   val get_dtyp_spec: thmtab -> string
     -> ((string * sort) list * (string * typ list) list) option;
+  val get_fun_typs: thmtab -> ((string * typ list) * typ) list;
   val get_fun_thms: thmtab -> string * typ -> thm list;
 
   val pretty_funtab: theory -> thm list CodegenConsts.Consttab.table -> Pretty.T;
@@ -766,6 +767,18 @@
   (check_thms c o these o Consttab.lookup funtab
     o CodegenConsts.norminst_of_typ thy) c_ty;
 
+fun get_fun_typs (thy, (funtab, dtcotab), _) =
+  (Consttab.dest funtab
+  |> map (fn (c, thm :: _) => (c, extr_typ thy thm)
+           | (c as (name, _), []) => (c, case AxClass.class_of_param thy name
+         of SOME class => (case ClassPackage.the_consts_sign thy class of (v, cs) =>
+              (Logic.varifyT o map_type_tfree (fn u as (w, _) =>
+                if w = v then TFree (v, [class]) else TFree u))
+              ((the o AList.lookup (op =) cs) name))
+          | NONE => Sign.the_const_type thy name)))
+  @ (Consttab.keys dtcotab
+  |> AList.make (Sign.const_instance thy));
+
 fun pretty_funtab thy funtab =
   funtab
   |> CodegenConsts.Consttab.dest
--- a/src/Pure/Tools/codegen_thingol.ML	Mon Sep 04 08:17:28 2006 +0200
+++ b/src/Pure/Tools/codegen_thingol.ML	Mon Sep 04 08:18:00 2006 +0200
@@ -71,7 +71,7 @@
     | Class of class list * (vname * (string * itype) list)
     | Classmember of class
     | Classinst of (class * (string * (vname * sort) list))
-          * ((class * inst list) list
+          * ((class * (string * inst list list)) list
         * (string * iterm) list);
   type module;
   type transact;
@@ -388,7 +388,7 @@
   | Class of class list * (vname * (string * itype) list)
   | Classmember of class
   | Classinst of (class * (string * (vname * sort) list))
-        * ((class * inst list) list
+        * ((class * (string * inst list list)) list
       * (string * iterm) list);
 
 datatype node = Def of def | Module of node Graph.T;
@@ -698,56 +698,124 @@
   ) ((AList.make (Graph.get_node modl) o flat o Graph.strong_conn) modl)
 
 (*
+(*FIXME: graph-based approach is better.
+* build graph
+* implement flat_classops on sort level, not class level
+* flat_instances bleibt wie es ist
+*)
+fun flat_classops modl =
+  let
+    fun add_ancestry class anc =
+      let
+        val SOME (Class (super_classes, (v, ops))) = AList.lookup (op =) modl class
+        val super_classees' = filter (not o member (fn (c', (c, _)) => c = c') anc) super_classes;
+      in
+        [(class, ops)] @ anc
+        |> fold add_ancestry super_classees'
+      end;
+  in
+    Symtab.empty
+    |> fold (
+         fn (class, Class _) =>
+              Symtab.update_new (class, maps snd (add_ancestry class []))
+           | _ => I
+       ) modl
+    |> the oo Symtab.lookup
+  end;
+
+fun flat_instances modl =
+  let
+    fun add_ancestry instance instsss anc =
+      let
+        val SOME (Classinst (_, (super_instances, ops))) = AList.lookup (op =) modl instance;
+        val super_instances' = filter (not o member (eq_fst (op =)) anc) super_instances;
+        val ops' = map (apsnd (rpair instsss)) ops;
+        (*FIXME: build types*)
+      in
+        [(instance, ops')] @ anc
+        |> fold (fn (_, (instance, instss)) => add_ancestry instance (instss :: instsss)) super_instances'
+      end;
+  in
+    Symtab.empty
+    |> fold (
+         fn (instance, Classinst _) =>
+              Symtab.update_new (instance, maps snd (add_ancestry instance [] []))
+           | _ => I
+       ) modl
+    |> the oo Symtab.lookup
+  end;
+
+fun flat_fundef classops instdefs is_classop (eqs, (vs, ty)) =
+  let
+    fun fold_map_snd' f (x, ys) = fold_map (f x) ys;
+    fun fold_map_snd f (x, ys) = fold_map f ys #-> (fn zs => pair (x, zs));
+    val names =
+      Name.context
+      |> fold Name.declare
+           (fold (fn (rhs, lhs) => fold add_varnames rhs #> add_varnames lhs) eqs []);
+    val opmap = [] : (string * (string * (string * itype) list) list) list;
+    val (params, tys) = (split_list o maps snd o maps snd) opmap;
+    (*fun name_ops v' class = 
+      (fold_map o fold_map_snd')
+        (fn (class, v) => fn (c, ty) => Name.variants [c] #-> (fn [p] =>
+          pair (class, v') (c, (ty, p))))
+          (classops class);
+    val (opsmap, _) = (fold_map o fold_map_snd') name_ops vs names;
+    (* --> (iterm * itype) list *)*)
+    fun flat_inst (Instance (instance, instss)) =
+          let
+            val xs : (string * (iterm * (itype * inst list list list))) list = instdefs instance
+            fun mk_t (t, (ty, instsss)) =
+              (Library.foldl (fn (t, instss) => t `$$ map (fst o snd) ((maps o maps) flat_inst instss))
+                (t, instss :: instsss), ty)
+          in
+            map (apsnd mk_t) xs
+          end
+      | flat_inst (Context (classes, (v, k))) =
+          let
+            val _ : 'a = classops (hd classes);
+          in
+            []
+          end
+          (*
+            val parm_map = nth ((the o AList.lookup (op =) octxt) v)
+              (if k = ~1 then 0 else k);
+          in map (apfst IVar o swap o snd) (case classes
+           of class::_ => (the o AList.lookup (op =) parm_map) class
+            | _ => (snd o hd) parm_map)*)
+    and flat_iterm (e as IConst (c, (lss, ty))) =
+          if is_classop c then let
+            val tab = (maps o maps) flat_inst lss;
+            val SOME (t, _) = AList.lookup (op =) tab c;
+          in t end else let
+            val (es, tys) = (split_list o map snd) ((maps o maps) flat_inst lss)
+          in IConst (c, (replicate (length lss) [], tys `--> ty)) `$$ es end
+      | flat_iterm (e as IVar _) =
+          e
+      | flat_iterm (e1 `$ e2) =
+          flat_iterm e1 `$ flat_iterm e2
+      | flat_iterm (v_ty `|-> e) =
+          v_ty `|-> flat_iterm e
+      | flat_iterm (INum (k, e)) =
+          INum (k, flat_iterm e)
+      | flat_iterm (IChar (s, e)) =
+          IChar (s, flat_iterm e)
+      | flat_iterm (ICase (((de, dty), es), e)) =
+          ICase (((flat_iterm de, dty), map (pairself flat_iterm) es), flat_iterm e);
+    fun flat_eq (lhs, rhs) = (map IVar params @ lhs, flat_iterm rhs);
+  in (map flat_eq eqs, (map (apsnd (K [])) vs, tys `--> ty)) end;
+
 fun flat_funs_datatypes modl =
-  map (
-   fn def as (_, Datatype _) => def
-    | (name, Fun (eqs, (vs, ty))) => let
-          val vs = fold (fn (rhs, lhs) => fold add_varnames rhs #> add_varnames lhs) eqs [];
-          fun fold_map_snd f (x, ys) = fold_map f ys #-> (fn zs => pair (x, zs));
-          fun all_ops_of class = [] : (class * (string * itype) list) list
-            (*FIXME; itype within current context*);
-          fun name_ops class = 
-            (fold_map o fold_map_snd)
-              (fn (c, ty) => Name.variants [c] #-> (fn [v] => pair (c, (ty, v)))) (all_ops_of class);
-          (*FIXME: should contain superclasses only once*)
-          val (octxt, _) = (fold_map o fold_map_snd) name_ops vs
-            (Name.make_context vs);
-          (* --> (iterm * itype) list *)
-          fun flat_classlookup (Instance (inst, lss)) =
-                (case get_def modl inst
-                 of (Classinst (_, (suparities, ops)))
-                      => maps (maps flat_classlookup o snd) suparities @ map (apsnd flat_iterm) ops
-                  | _ => error ("Bad instance: " ^ quote inst))
-            | flat_classlookup (Context (classes, (v, k))) =
-                let
-                  val parm_map = nth ((the o AList.lookup (op =) octxt) v)
-                    (if k = ~1 then 0 else k);
-                in map (apfst IVar o swap o snd) (case classes
-                 of class::_ => (the o AList.lookup (op =) parm_map) class
-                  | _ => (snd o hd) parm_map)
-                end
-          and flat_iterm (e as IConst (c, (lss, ty))) =
-                let
-                  val (es, tys) = split_list ((maps o maps) flat_classlookup lss)
-                in IConst (c, ([], tys `--> ty)) `$$ es end
-                (*FIXME Eliminierung von Projektionen*)
-            | flat_iterm (e as IVar _) =
-                e
-            | flat_iterm (e1 `$ e2) =
-                flat_iterm e1 `$ flat_iterm e2
-            | flat_iterm (v_ty `|-> e) =
-                v_ty `|-> flat_iterm e
-            | flat_iterm (INum (k, e)) =
-                INum (k, flat_iterm e)
-            | flat_iterm (IChar (s, e)) =
-                IChar (s, flat_iterm e)
-            | flat_iterm (ICase (((de, dty), es), e)) =
-                ICase (((flat_iterm de, dty), map (pairself flat_iterm) es), flat_iterm e);
-        in
-          (name, Fun (map (fn (lhs, rhs) => (map flat_iterm lhs, flat_iterm rhs)) eqs,
-            ([], maps ((maps o maps) (map (fst o snd) o snd) o snd) octxt `--> ty)))
-        end
-  ) (flat_module modl);
+  let
+    val modl = flat_module modl;
+    val classops = flat_classops modl;
+    val instdefs = flat_instances modl;
+    val is_classop = is_some o AList.lookup (op =) modl;
+  in map_filter (
+   fn def as (_, Datatype _) => SOME def
+    | (name, Fun funn) => SOME (name, (Fun (flat_fundef classops instdefs is_classop funn)))
+    | _ => NONE
+  ) end;
 *)
 
 val add_deps_of_typparms =
@@ -817,9 +885,10 @@
       |> insert (op =) class
       |> insert (op =) tyco
       |> add_deps_of_typparms vs
-      |> fold (fn (supclass, ls) =>
+      |> fold (fn (supclass, (supinst, lss)) =>
             insert (op =) supclass
-            #> fold add_deps_of_classlookup ls
+            #> insert (op =) supinst
+            #> (fold o fold) add_deps_of_classlookup lss
       ) suparities
       |> fold (fn (name, e) =>
             insert (op =) name