slight improvement in serializer, stub for code generator theorems added
authorhaftmann
Fri, 17 Mar 2006 14:19:24 +0100
changeset 19280 5091dc43817b
parent 19279 48b527d0331b
child 19281 b411f25fff25
slight improvement in serializer, stub for code generator theorems added
src/Pure/Tools/ROOT.ML
src/Pure/Tools/class_package.ML
src/Pure/Tools/codegen_package.ML
src/Pure/Tools/codegen_serializer.ML
src/Pure/Tools/codegen_theorems.ML
--- a/src/Pure/Tools/ROOT.ML	Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/ROOT.ML	Fri Mar 17 14:19:24 2006 +0100
@@ -7,6 +7,9 @@
 (*class package*)
 use "class_package.ML";
 
+(*code generator theorems*)
+use "codegen_theorems.ML";
+
 (*code generator, 1st generation*)
 use "../codegen.ML";
 
--- a/src/Pure/Tools/class_package.ML	Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/class_package.ML	Fri Mar 17 14:19:24 2006 +0100
@@ -65,7 +65,8 @@
   name_axclass: string,
   intro: thm option,
   var: string,
-  consts: (string * typ) list
+  consts: (string * (string * typ)) list
+    (*locale parameter ~> toplevel const*)
 };
 
 structure ClassData = TheoryDataFun (
@@ -95,7 +96,7 @@
             Pretty.str ("class variable: " ^ var),
             (Pretty.block o Pretty.fbreaks) (
               Pretty.str "constants: "
-              :: map (fn (c, ty) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
+              :: map (fn (_, (c, ty)) => Pretty.str (c ^ " :: " ^ Sign.string_of_typ thy ty)) consts
             )
           ]
       in
@@ -120,7 +121,9 @@
     of NONE => error ("undeclared operational class " ^ quote class)
      | SOME data => data;
 
-fun is_class thy cls =
+val is_class = is_some oo lookup_class_data;
+
+fun is_operational_class thy cls =
   lookup_class_data thy cls
   |> Option.map (not o null o #consts)
   |> the_default false;
@@ -129,7 +132,7 @@
   let
     val classes = Sign.classes_of thy;
     fun get_sort class =
-      if is_class thy class
+      if is_operational_class thy class
       then [class]
       else operational_sort_of thy (Sorts.superclasses classes class);
   in
@@ -144,14 +147,14 @@
     Sorts.superclasses (Sign.classes_of thy) class
     |> operational_sort_of thy
   else
-    error ("no syntactic class: " ^ class);
+    error ("no class: " ^ class);
 
 fun get_superclass_derivation thy (subclass, superclass) =
   if subclass = superclass
     then SOME [subclass]
     else case Graph.find_paths ((fst o fst o ClassData.get) thy) (subclass, superclass)
       of [] => NONE
-       | (p::_) => (SOME o filter (is_class thy)) p;
+       | (p::_) => (SOME o filter (is_operational_class thy)) p;
 
 fun the_ancestry thy classes =
   let
@@ -170,14 +173,19 @@
   map_type_tfree (fn u as (w, _) =>
     if w = v then ty_subst else TFree u);
 
+fun the_parm_map thy class =
+  let
+    val data = the_class_data thy class
+  in (#consts data) end;
+
 fun the_consts_sign thy class =
   let
     val data = the_class_data thy class
-  in (#var data, #consts data) end;
+  in (#var data, (map snd o #consts) data) end;
 
 fun the_inst_sign thy (class, tyco) =
   let
-    val _ = if is_class thy class then () else error ("no syntactic class: " ^ class);
+    val _ = if is_operational_class thy class then () else error ("no operational class: " ^ class);
     val arity =
       Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class];
     val clsvar = (#var o the_class_data thy) class;
@@ -216,13 +224,16 @@
     tab
     |> Symtab.update (class, [])),
     consttab
-    |> fold (fn (c, _) => Symtab.update (c, class)) consts
+    |> fold (fn (_, (c, _)) => Symtab.update (c, class)) consts
   ));
 
 fun add_inst_data (class, inst) =
   ClassData.map (fn ((gr, tab), consttab) =>
-     ((gr, tab |>
-    (Symtab.map_entry class (AList.update (op =) inst))), consttab));
+    let
+      val undef_supclasses = class :: (filter (Symtab.defined tab) (Graph.all_succs gr [class]));
+    in
+     ((gr, tab |> fold (fn class => Symtab.map_entry class (AList.update (op =) inst)) undef_supclasses), consttab)
+    end);
 
 
 (* name handling *)
@@ -234,7 +245,7 @@
   map (fn class => (the_class_data thy class; class)) (Sign.certify_sort thy sort);
 
 fun intern_class thy =
-  certify_class thy o Sign.intern_class thy;
+certify_class thy o Sign.intern_class thy;
 
 fun intern_sort thy =
   certify_sort thy o Sign.intern_sort thy;
@@ -356,18 +367,19 @@
       |> map (#name_axclass o the_class_data thy)
       |> Sorts.certify_sort (Sign.classes_of thy)
       |> null ? K (Sign.defaultS thy);
-    val supcs = (Library.flat o map (snd o the_consts_sign thy) o the_ancestry thy)
-      supclasses;
     val expr = if null supclasses
       then Locale.empty
       else
        (Locale.Merge o map (Locale.Locale o #name_locale o the_class_data thy)) supclasses;
+    val mapp_sup = AList.make
+      (the o AList.lookup (op =) ((Library.flat o map (the_parm_map thy) o the_ancestry thy) supclasses))
+      ((map (fst o fst) o Locale.parameters_of_expr thy) expr);
     fun extract_tyvar_consts thy name_locale =
       let
         fun extract_tyvar_name thy tys =
           fold (curry add_typ_tfrees) tys []
           |> (fn [(v, sort)] =>
-                    if Sorts.sort_le (Sign.classes_of thy) (swap (sort, supsort))
+              if Sorts.sort_le (Sign.classes_of thy) (swap (sort, supsort))
                     then v
                     else error ("illegal sort constraint on class type variable: " ^ Sign.string_of_sort thy sort)
                | [] => error ("no class type variable")
@@ -377,10 +389,9 @@
           |> map (apsnd Syntax.unlocalize_mixfix)
         val v = (extract_tyvar_name thy o map (snd o fst)) consts1;
         val consts2 = map ((apfst o apsnd) (subst_clsvar v (TFree (v, [])))) consts1;
-      in (v, chop (length supcs) consts2) end;
+      in (v, chop (length mapp_sup) consts2) end;
     fun add_consts v raw_cs_sup raw_cs_this thy =
       let
-        val mapp_sub = map2 (fn ((c, _), _) => pair c) raw_cs_sup supcs
         fun add_global_const ((c, ty), syn) thy =
           thy
           |> Sign.add_consts_i [(c, ty |> subst_clsvar v (TFree (v, Sign.defaultS thy)), syn)]
@@ -388,7 +399,6 @@
       in
         thy
         |> fold_map add_global_const raw_cs_this
-        |-> (fn mapp_this => pair (mapp_sub @ mapp_this, map snd mapp_this))
       end;
     fun extract_assumes thy name_locale cs_mapp =
       let
@@ -400,7 +410,7 @@
       in
         (map prep_asm o Locale.local_asms_of thy) name_locale
       end;
-    fun add_global_constraint v class (c, ty) thy =
+    fun add_global_constraint v class (_, (c, ty)) thy =
       thy
       |> Sign.add_const_constraint_i (c, SOME (subst_clsvar v (TFree (v, [class])) ty));
     fun mk_const thy class v (c, ty) =
@@ -412,15 +422,15 @@
           `(fn thy => extract_tyvar_consts thy name_locale)
     #-> (fn (v, (raw_cs_sup, raw_cs_this)) =>
           add_consts v raw_cs_sup raw_cs_this
-    #-> (fn (cs_map, cs_this) =>
-          `(fn thy => extract_assumes thy name_locale cs_map)
+    #-> (fn mapp_this =>
+          `(fn thy => extract_assumes thy name_locale (mapp_sup @ mapp_this))
     #-> (fn loc_axioms =>
           add_axclass_i (bname, supsort) loc_axioms
     #-> (fn (name_axclass, (_, ax_axioms)) =>
-          fold (add_global_constraint v name_axclass) cs_this
-    #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, intro, v, cs_this))
+          fold (add_global_constraint v name_axclass) mapp_this
+    #> add_class_data (name_locale, (supclasses, name_locale, name_axclass, intro, v, mapp_this))
     #> prove_interpretation_i (NameSpace.base name_locale, [])
-          (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (supcs @ cs_this))
+          (Locale.Locale name_locale) (map (SOME o mk_const thy name_axclass v) (map snd (mapp_sup @ mapp_this)))
           ((ALLGOALS o resolve_tac) ax_axioms)
     #> pair ctxt
     )))))
@@ -490,7 +500,7 @@
         val data = the_class_data theory class;
         val subst_ty = map_type_tfree (fn (var as (v, _)) =>
           if #var data = v then ty_inst else TFree var)
-      in (map (apsnd subst_ty) o #consts) data end;
+      in (map (apsnd subst_ty o snd) o #consts) data end;
     val cs = (Library.flat o map get_consts) classes;
     fun get_remove_contraint c thy =
       let
@@ -570,7 +580,7 @@
     val _ = writeln ("sub " ^ name)
     val suplocales = (fn Locale.Merge es => map (fn Locale.Locale n => n) es) expr;
     val _ = writeln ("super " ^ commas suplocales)
-    fun get_c name = 
+    fun get_c name =
       (map (NameSpace.base o fst o fst) o Locale.parameters_of thy) name;
     fun get_a name =
       (map (NameSpace.base o fst o fst) o Locale.local_asms_of thy) name;
@@ -663,7 +673,7 @@
     fun mk_lookup (sort_def, (Type (tyco, tys))) =
           map (fn class => Instance ((class, tyco),
             map2 (curry mk_lookup)
-              ((fst o the o AList.lookup (op =) (the_instances thy class)) tyco)
+              (map (operational_sort_of thy) (Sorts.mg_domain (Sign.classes_arities_of thy) tyco [class]))
               tys)
           ) sort_def
       | mk_lookup (sort_def, TVar ((vname, _), sort_use)) =
@@ -673,7 +683,7 @@
               in Lookup (deriv, (vname, classindex)) end;
           in map mk_look sort_def end;
   in
-    sortctxt
+ sortctxt
     |> map (tab_lookup o fst)
     |> map (apfst (operational_sort_of thy))
     |> filter (not o null o fst)
@@ -690,7 +700,7 @@
         | SOME class =>
             let
               val data = the_class_data thy class;
-              val sign = (Type.varifyT o the o AList.lookup (op =) (#consts data)) c;
+              val sign = (Type.varifyT o the o AList.lookup (op =) ((map snd o #consts) data)) c;
               val match_tab = Sign.typ_match thy (sign, typ_def) Vartab.empty;
               val v : string = case Vartab.lookup match_tab (#var data, 0)
                 of SOME (_, TVar ((v, _), _)) => v;
@@ -751,13 +761,18 @@
     Scan.optional (P.$$$ "+" |-- P.!!! (Scan.repeat1 P.context_element)) [] ||
   Scan.repeat1 P.context_element >> pair Locale.empty);
 
+val class_subP = P.name -- Scan.repeat (P.$$$ "+" |-- P.name) >> (op ::);
+val class_bodyP = P.!!! (Scan.repeat1 P.context_element);
+
 val classP =
   OuterSyntax.command classK "operational type classes" K.thy_decl (
     P.name --| P.$$$ "="
-    -- Scan.optional (Scan.repeat1 (P.name --| P.$$$ "+")) []
-    -- Scan.optional (P.!!! (Scan.repeat1 P.context_element)) []
-      >> (Toplevel.theory_context
-          o (fn ((bname, supclasses), elems) => class bname supclasses elems)));
+    -- (
+      class_subP --| P.$$$ "+" -- class_bodyP
+      || class_subP >> rpair []
+      || class_bodyP >> pair []
+    ) >> (Toplevel.theory_context
+          o (fn (bname, (supclasses, elems)) => class bname supclasses elems)));
 
 val instanceP =
   OuterSyntax.command instanceK "prove type arity or subclass relation" K.thy_goal ((
--- a/src/Pure/Tools/codegen_package.ML	Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/codegen_package.ML	Fri Mar 17 14:19:24 2006 +0100
@@ -136,7 +136,7 @@
 
 fun eq_typ thy (ty1, ty2) =
   Sign.typ_instance thy (ty1, ty2)
-  andalso Sign.typ_instance thy (ty2, ty1);
+    andalso Sign.typ_instance thy (ty2, ty1);
 
 fun is_overloaded thy c = case Defs.specifications_of (Theory.defs_of thy) c
  of [] => true
@@ -754,7 +754,8 @@
               fun gen_membr (m, ty) trns =
                 trns
                 |> mk_fun thy tabs true (m, ty)
-                |-> (fn NONE => error ("could not derive definition for member " ^ quote m)
+                |-> (fn NONE => error ("could not derive definition for member "
+                          ^ quote m ^ " :: " ^ Sign.string_of_typ thy ty)
                       | SOME (funn, ty_use) =>
                     (fold_map o fold_map) (exprgen_classlookup thy tabs)
                        (ClassPackage.extract_classlookup_member thy (ty, ty_use))
@@ -913,7 +914,7 @@
 
 fun eqextr_defs thy (deftab, _) (c, ty) =
   Option.mapPartial (get_first (fn (ty', (thm, _)) =>
-    if eq_typ thy (ty, ty')
+    if Sign.typ_instance thy (ty, ty') 
     then SOME ([thm], ty')
     else NONE
   )) (Symtab.lookup deftab c);
--- a/src/Pure/Tools/codegen_serializer.ML	Fri Mar 17 10:04:27 2006 +0100
+++ b/src/Pure/Tools/codegen_serializer.ML	Fri Mar 17 14:19:24 2006 +0100
@@ -413,7 +413,7 @@
           str ")"
         ]
       end;
-    fun ml_from_sortlookup fxy ls =
+    fun ml_from_sortlookup fxy lss =
       let
         fun from_label l =
           Pretty.block [str "#", ml_from_label l];
@@ -437,10 +437,10 @@
               from_lookup BR classes (str v)
           | from_classlookup fxy (Lookup (classes, (v, i))) =
               from_lookup BR (string_of_int (i+1) :: classes) (str v)
-      in case ls
+      in case lss
        of [] => str "()"
-        | [l] => from_classlookup fxy l
-        | ls => (Pretty.list "(" ")" o map (from_classlookup NOBR)) ls
+        | [ls] => from_classlookup fxy ls
+        | lss => (Pretty.list "(" ")" o map (from_classlookup NOBR)) lss
       end;
     fun ml_from_tycoexpr fxy (tyco, tys) =
       let
@@ -738,7 +738,9 @@
                 ml_from_label supclass
                 :: str "="
                 :: (str o resolv) supinst
-                :: map (ml_from_sortlookup NOBR) lss
+                :: (if null lss andalso (not o null) arity
+                     then [str "()"]
+                     else map (ml_from_sortlookup NOBR) lss)
               );
             fun from_memdef (m, ((m', def), lss)) =
               (ml_from_funs [(m', def)], (Pretty.block o Pretty.breaks) (
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Pure/Tools/codegen_theorems.ML	Fri Mar 17 14:19:24 2006 +0100
@@ -0,0 +1,148 @@
+(*  Title:      Pure/Tools/CODEGEN_THEOREMS.ML
+    ID:         $Id$
+    Author:     Florian Haftmann, TU Muenchen
+
+Theorems used for code generation.
+*)
+
+signature CODEGEN_THEOREMS =
+sig
+  val add_notify: (string option -> theory -> theory) -> theory -> theory;
+  val add_preproc: (theory -> thm list -> thm list) -> theory -> theory;
+  val add_funn: thm -> theory -> theory;
+  val add_pred: thm -> theory -> theory;
+  val add_unfold: thm -> theory -> theory;
+  val preprocess: theory -> thm list -> thm list;
+  val preprocess_term: theory -> term -> term;
+end;
+
+structure CodegenTheorems: CODEGEN_THEOREMS =
+struct
+
+(** auxiliary **)
+
+fun dest_funn thm =
+  case try (fst o dest_Const o fst o strip_comb o fst o Logic.dest_equals o prop_of) thm
+   of SOME c => SOME (c, thm)
+    | NONE => NONE;
+
+fun dest_pred thm =
+  case try (fst o dest_Const o fst o strip_comb o snd o Logic.dest_implies o prop_of) thm
+   of SOME c => SOME (c, thm)
+    | NONE => NONE;
+
+
+(** theory data **)
+
+datatype procs = Procs of {
+  preprocs: (serial * (theory -> thm list -> thm list)) list,
+  notify: (serial * (string option -> theory -> theory)) list
+};
+
+fun mk_procs (preprocs, notify) = Procs { preprocs = preprocs, notify = notify };
+fun map_procs f (Procs { preprocs, notify }) = mk_procs (f (preprocs, notify));
+fun merge_procs _ (Procs { preprocs = preprocs1, notify = notify1 },
+  Procs { preprocs = preprocs2, notify = notify2 }) =
+    mk_procs (AList.merge (op =) (K true) (preprocs1, preprocs2),
+      AList.merge (op =) (K true) (notify1, notify2));
+
+datatype codethms = Codethms of {
+  funns: thm list Symtab.table,
+  preds: thm list Symtab.table,
+  unfolds: thm list
+};
+
+fun mk_codethms ((funns, preds), unfolds) =
+  Codethms { funns = funns, preds = preds, unfolds = unfolds };
+fun map_codethms f (Codethms { funns, preds, unfolds }) =
+  mk_codethms (f ((funns, preds), unfolds));
+fun merge_codethms _ (Codethms { funns = funns1, preds = preds1, unfolds = unfolds1 },
+  Codethms { funns = funns2, preds = preds2, unfolds = unfolds2 }) =
+    mk_codethms ((Symtab.join (K (uncurry (fold (insert eq_thm)))) (funns1, funns2),
+        Symtab.join (K (uncurry (fold (insert eq_thm)))) (preds1, preds2)),
+          fold (insert eq_thm) unfolds1 unfolds2);
+
+datatype T = T of {
+  procs: procs,
+  codethms: codethms
+};
+
+fun mk_T (procs, codethms) = T { procs = procs, codethms = codethms };
+fun map_T f (T { procs, codethms }) = mk_T (f (procs, codethms));
+fun merge_T pp (T { procs = procs1, codethms = codethms1 },
+  T { procs = procs2, codethms = codethms2 }) =
+    mk_T (merge_procs pp (procs1, procs2), merge_codethms pp (codethms1, codethms2));
+
+structure CodegenTheorems = TheoryDataFun
+(struct
+  val name = "Pure/CodegenTheorems";
+  type T = T;
+  val empty = mk_T (mk_procs ([], []),
+    mk_codethms ((Symtab.empty, Symtab.empty), []));
+  val copy = I;
+  val extend = I;
+  val merge = merge_T;
+  fun print _ _ = ();
+end);
+
+val _ = Context.add_setup CodegenTheorems.init;
+
+
+(** notifiers and preprocessors **)
+
+fun add_notify f =
+  CodegenTheorems.map (map_T (fn (procs, codethms) =>
+    (procs |> map_procs (fn (preprocs, notify) =>
+      (preprocs, (serial (), f) :: notify)), codethms)));
+
+fun notify_all c thy =
+  fold (fn f => f c) (((fn Procs { notify, ... } => map snd notify)
+    o (fn T { procs, ... } => procs) o CodegenTheorems.get) thy) thy;
+
+fun add_preproc f =
+  CodegenTheorems.map (map_T (fn (procs, codethms) =>
+    (procs |> map_procs (fn (preprocs, notify) =>
+      ((serial (), f) :: preprocs, notify)), codethms)))
+  #> notify_all NONE;
+
+fun preprocess thy =
+  fold (fn f => f thy) (((fn Procs { preprocs, ... } => map snd preprocs)
+    o (fn T { procs, ... } => procs) o CodegenTheorems.get) thy);
+
+fun preprocess_term thy t =
+  let
+    val x = Free (variant (add_term_names (t, [])) "x", fastype_of t);
+    (* fake definition *)
+    val eq = setmp quick_and_dirty true (SkipProof.make_thm thy)
+      (Logic.mk_equals (x, t));
+    fun err () = error "preprocess_term: bad preprocessor"
+  in case map prop_of (preprocess thy [eq]) of
+      [Const ("==", _) $ x' $ t'] => if x = x' then t' else err ()
+    | _ => err ()
+  end;
+
+fun add_unfold thm =
+  CodegenTheorems.map (map_T (fn (procs, codethms) =>
+    (procs, codethms |> map_codethms (fn (defs, unfolds) =>
+      (defs, thm :: unfolds)))))
+
+fun add_funn thm =
+  case dest_funn thm
+   of SOME (c, thm) =>
+    CodegenTheorems.map (map_T (fn (procs, codethms) =>
+      (procs, codethms |> map_codethms (fn ((funns, preds), unfolds) =>
+        ((funns |> Symtab.default (c, []) |> Symtab.map (fn thms => thms @ [thm]), preds), unfolds)))))
+    | NONE => error ("not a function equation: " ^ string_of_thm thm);
+
+fun add_pred thm =
+  case dest_pred thm
+   of SOME (c, thm) =>
+    CodegenTheorems.map (map_T (fn (procs, codethms) =>
+      (procs, codethms |> map_codethms (fn ((funns, preds), unfolds) =>
+        ((funns, preds |> Symtab.default (c, []) |> Symtab.map (fn thms => thms @ [thm])), unfolds)))))
+    | NONE => error ("not a predicate clause: " ^ string_of_thm thm);
+
+
+(** isar **)
+
+end; (* struct *)