improved class handling
authorhaftmann
Tue, 06 Dec 2005 16:07:25 +0100
changeset 18360 a2c9506b62a7
parent 18359 02a830bab542
child 18361 3126d01e9e35
improved class handling
src/Pure/Tools/class_package.ML
src/Pure/Tools/codegen_package.ML
src/Pure/Tools/codegen_serializer.ML
src/Pure/Tools/codegen_thingol.ML
--- a/src/Pure/Tools/class_package.ML	Tue Dec 06 16:07:10 2005 +0100
+++ b/src/Pure/Tools/class_package.ML	Tue Dec 06 16:07:25 2005 +0100
@@ -11,7 +11,7 @@
   val the_consts: theory -> class -> string list
   val the_tycos: theory -> class -> (string * string) list
 
-  val is_class: theory -> class -> bool
+  val syntactic_sort_of: theory -> sort -> sort
   val get_arities: theory -> sort -> string -> sort list
   val get_superclasses: theory -> class -> class list
   val get_const_sign: theory -> string -> string -> typ
@@ -140,21 +140,28 @@
 
 (* class queries *)
 
-fun is_class thy = is_some o lookup_class_data thy;
-
-fun filter_class thy = filter (is_class thy);
+fun is_class thy cls = lookup_class_data thy cls |> Option.map (not o null o #consts) |> the_default false;
 
-fun assert_class thy class =
-  if is_class thy class then class
-  else error ("not a class: " ^ quote class);
+fun syntactic_sort_of thy sort =
+  let
+    val classes = Sign.classes_of thy;
+    fun get_sort cls =
+      if is_class thy cls
+      then [cls]
+      else syntactic_sort_of thy (Sorts.superclasses classes cls);
+  in
+    map get_sort sort
+    |> Library.flat
+    |> Sorts.norm_sort classes
+  end;
 
 fun get_arities thy sort tycon =
   Sorts.mg_domain (Sign.classes_arities_of thy) tycon sort
-  |> (map o map) (assert_class thy);
+  |> map (syntactic_sort_of thy);
 
 fun get_superclasses thy class =
   Sorts.superclasses (Sign.classes_of thy) class
-  |> filter_class thy;
+  |> syntactic_sort_of thy;
 
 
 (* instance queries *)
@@ -202,7 +209,7 @@
 
 fun extract_sortctxt thy ty =
   (typ_tfrees o Type.no_tvars) ty
-  |> map (apsnd (filter_class thy))
+  |> map (apsnd (syntactic_sort_of thy))
   |> filter (not o null o snd);
 
 datatype sortlookup = Instance of (class * string) * sortlookup list list
@@ -224,7 +231,7 @@
     fun mk_lookup (sort_def, (Type (tycon, tys))) =
           let
             val arity_lookup = map2 (curry mk_lookup)
-              (map (filter_class thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tycon sort_def)) tys
+              (map (syntactic_sort_of thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tycon sort_def)) tys
           in map (fn class => Instance ((class, tycon), arity_lookup)) sort_def end
       | mk_lookup (sort_def, TVar ((vname, _), sort_use)) =
           let
@@ -235,7 +242,7 @@
   in
     extract_sortctxt thy ((fst o Type.freeze_thaw_type) raw_typ_def)
     |> map (tab_lookup o fst)
-    |> map (apfst (filter_class thy))
+    |> map (apfst (syntactic_sort_of thy))
     |> filter (not o null o fst)
     |> map mk_lookup
   end;
--- a/src/Pure/Tools/codegen_package.ML	Tue Dec 06 16:07:10 2005 +0100
+++ b/src/Pure/Tools/codegen_package.ML	Tue Dec 06 16:07:25 2005 +0100
@@ -595,7 +595,7 @@
 fun exprgen_sort_default thy defs sort trns =
   trns
   |> fold_map (ensure_def_class thy defs)
-       (sort |> filter (ClassPackage.is_class thy) |> map (idf_of_name thy nsp_class))
+       (sort |> ClassPackage.syntactic_sort_of thy |> map (idf_of_name thy nsp_class))
   |-> (fn sort => succeed sort)
 
 fun exprgen_type_default thy defs (TVar _) trns =
@@ -763,17 +763,11 @@
 fun defgen_clsmem thy (defs as (_, _, _)) f trns =
   case name_of_idf thy nsp_mem f
    of SOME clsmem =>
-        let
-          val _ = debug 10 (fn _ => "CLSMEM " ^ quote clsmem) ();
-          val _ = debug 10 (fn _ => (the o ClassPackage.lookup_const_class thy) clsmem) ();
-          val cls = idf_of_name thy nsp_class ((the o ClassPackage.lookup_const_class thy) clsmem);
-          val ty = ClassPackage.get_const_sign thy "'a" clsmem;
-        in
-          trns
-          |> debug 5 (fn _ => "trying defgen class member for " ^ quote f)
-          |> (invoke_cg_type thy defs o devarify_type) ty
-          |-> (fn ty => succeed (Classmember (cls, "a", ty), []))
-        end
+        trns
+        |> debug 5 (fn _ => "trying defgen class member for " ^ quote f)
+        |> ensure_def_class thy defs (idf_of_name thy nsp_class ((the o ClassPackage.lookup_const_class thy) clsmem))
+        ||>> (invoke_cg_type thy defs o devarify_type) (ClassPackage.get_const_sign thy "'a" clsmem)
+        |-> (fn (cls, ty) => succeed (Classmember (cls, "a", ty), []))
     | _ =>
         trns |> fail ("no class member found for " ^ quote f)
 
@@ -1332,7 +1326,7 @@
     |> (if is_some consts then generate_code (the consts) else pair [])
     |-> (fn [] => `(serializer' NONE o #modl o CodegenData.get)
           | consts => `(serializer' (SOME consts) o #modl o CodegenData.get))
-    |-> (fn code => (setmp print_mode [] (use_code o Pretty.output) code; I))
+    |-> (fn code => ((use_code o Pretty.output) code; I))
   end;
 
 
--- a/src/Pure/Tools/codegen_serializer.ML	Tue Dec 06 16:07:10 2005 +0100
+++ b/src/Pure/Tools/codegen_serializer.ML	Tue Dec 06 16:07:25 2005 +0100
@@ -596,7 +596,7 @@
         |> translate_string replace_invalid
         |> suffix_it
         |> (fn name' => if name = name' then NONE else SOME name')
-    end;
+      end;
     fun ml_from_module (name, ps) =
       Pretty.chunks ([
         Pretty.str ("structure " ^ name ^ " = "),
@@ -958,6 +958,7 @@
       end;
     fun haskell_from_classes defs =
       let
+        val _ = writeln ("IDS: " ^ (commas o map fst) defs)
         fun mk_member (f, ty) =
           Pretty.block [
             Pretty.str (f ^ " ::"),
@@ -1019,7 +1020,7 @@
             haskell_from_sctxt arity,
             Pretty.str ((upper_first o resolv) clsname),
             Pretty.str " ",
-            Pretty.str ((upper_first o resolv) tyco),
+            haskell_from_type NOBR (IType (tyco, (map (IVarT o rpair [] o fst)) arity)),
             Pretty.str " where",
             Pretty.fbrk,
             Pretty.chunks (map (fn (member, const) =>
@@ -1045,7 +1046,24 @@
           Pretty.fbrk,
           Pretty.chunks (separate (Pretty.str "") ps)
         ];
-    fun haskell_validator s = NONE;
+    fun haskell_validator name =
+      let
+        fun replace_invalid c =
+          if (Char.isAlphaNum o the o Char.fromString) c orelse c = "'"
+          andalso not (NameSpace.separator = c)
+          then c
+          else "_"
+        fun suffix_it name =
+          name
+          |> member (op =) CodegenThingol.prims ? suffix "'"
+          |> has_prim prims ? suffix "'"
+          |> (fn name' => if name = name' then name else suffix_it name')
+      in
+        name
+        |> translate_string replace_invalid
+        |> suffix_it
+        |> (fn name' => if name = name' then NONE else SOME name')
+      end;
     fun eta_expander "Pair" = 2
       | eta_expander "if" = 3
       | eta_expander s =
--- a/src/Pure/Tools/codegen_thingol.ML	Tue Dec 06 16:07:10 2005 +0100
+++ b/src/Pure/Tools/codegen_thingol.ML	Tue Dec 06 16:07:25 2005 +0100
@@ -60,6 +60,7 @@
   type gen_defgen = string -> transact -> (def * string list) transact_fin;
   val pretty_def: def -> Pretty.T;
   val pretty_module: module -> Pretty.T; 
+  val pretty_deps: module -> Pretty.T;
   val empty_module: module;
   val get_def: module -> string -> def;
   val merge_module: module * module -> module;
@@ -539,6 +540,24 @@
           Pretty.block [Pretty.str name, Pretty.str " :=", Pretty.brk 1, pretty_def def]
   in pretty ("//", Module modl) end;
 
+fun pretty_deps modl =
+  let
+    fun one_node key =
+      (Pretty.block o Pretty.fbreaks) (
+        Pretty.str key
+        :: (map (fn s => Pretty.str ("<- " ^ s)) o Graph.imm_preds modl) key
+        @ (map (fn s => Pretty.str ("-> " ^ s)) o Graph.imm_succs modl) key
+        @ (the_list oo Option.mapPartial) ((fn Module modl' => SOME (pretty_deps modl') | _ => NONE) o Graph.get_node modl) (SOME key)
+      );
+  in
+    modl
+    |> Graph.strong_conn
+    |> List.concat
+    |> rev
+    |> map one_node
+    |> Pretty.chunks
+  end;
+
 
 (* name handling *)
 
@@ -1139,10 +1158,9 @@
           let
             val _ = writeln ("class 1");
             val varnames_ctxt =
-              sortctxt
-              |> length o Library.flat o map snd
-              |> Term.invent_names ((vars_of_iexprs o map snd) ds @ (vars_of_ipats o Library.flat o map fst) ds) "d"
-              |> unflat (map snd sortctxt);
+              dig
+                (Term.invent_names ((vars_of_iexprs o map snd) ds @ (vars_of_ipats o Library.flat o map fst) ds) "d" o length)
+                (map snd sortctxt);
             val _ = writeln ("class 2");
             val vname_alist = map2 (fn (vt, sort) => fn vs => (vt, vs ~~ sort)) sortctxt varnames_ctxt;
             val _ = writeln ("class 3");
@@ -1293,7 +1311,7 @@
       | seri prfx ds =
           s_def (resolver prfx) (map (fn (name, Def def) => (resolver prfx (prfx @ [name] |> NameSpace.pack), def)) ds)
   in
-    s_module (name_root, (map (seri [])
+    setmp print_mode [] s_module (name_root, (map (seri [])
       ((map (AList.make (Graph.get_node module)) o rev o Graph.strong_conn) module)))
   end;