src/Pure/Tools/codegen_package.ML
changeset 19136 00ade10f611d
parent 19111 1f6112de1d0f
child 19150 1457d810b408
--- a/src/Pure/Tools/codegen_package.ML	Sat Feb 25 15:11:35 2006 +0100
+++ b/src/Pure/Tools/codegen_package.ML	Sat Feb 25 15:19:00 2006 +0100
@@ -13,7 +13,6 @@
     -> string * typ -> (thm list * typ) option;
   type eqextr_default = theory -> auxtab
     -> string * typ -> ((thm list * term option) * typ) option;
-  type defgen;
   type appgen = theory -> auxtab
     -> (string * typ) * term list -> CodegenThingol.transact
     -> CodegenThingol.iexpr * CodegenThingol.transact;
@@ -26,7 +25,7 @@
     -> theory -> theory;
   val add_prim_tyco: xstring -> (string * string)
     -> theory -> theory;
-  val add_prim_const: xstring * string option -> (string * string)
+  val add_prim_const: xstring -> (string * string)
     -> theory -> theory;
   val add_prim_i: string -> (string * CodegenThingol.prim list)
     -> theory -> theory;
@@ -39,7 +38,10 @@
     -> theory -> theory;
   val set_int_tyco: string -> theory -> theory;
 
-  val codegen_incr: term -> theory -> (string * CodegenThingol.def) list * theory;
+  val codegen_incr: term -> theory -> (CodegenThingol.iexpr * (string * CodegenThingol.def) list) * theory;
+  val is_dtcon: string -> bool;
+  val consts_of_idfs: theory -> string list -> (string * (string * typ)) list
+
   val get_ml_fun_datatype: theory -> (string -> string)
     -> ((string * CodegenThingol.funn) list -> Pretty.T)
         * ((string * CodegenThingol.datatyp) list -> Pretty.T);
@@ -141,6 +143,11 @@
   Sign.typ_instance thy (ty1, ty2)
   andalso Sign.typ_instance thy (ty2, ty1);
 
+fun is_overloaded thy c = case Defs.specifications_of (Theory.defs_of thy) c
+ of [] => false
+  | [(ty, _)] => not (eq_typ thy (ty, Sign.the_const_type thy c))
+  | _ => true;
+
 structure InstNameMangler = NameManglerFun (
   type ctxt = theory;
   type src = string * (class * string);
@@ -154,25 +161,24 @@
 );
 
 structure ConstNameMangler = NameManglerFun (
-  type ctxt = theory * deftab;
-  type src = string * (typ * typ);
-  val ord = prod_ord string_ord (prod_ord Term.typ_ord Term.typ_ord);
-  fun mk (thy, deftab) ((c, (ty_decl, ty)), i) =
+  type ctxt = theory;
+  type src = string * typ;
+  val ord = prod_ord string_ord Term.typ_ord;
+  fun mk thy ((c, ty), i) =
     let
-      val thyname = case (get_first
-          (fn (ty', (_, thyname)) => if eq_typ thy (ty, ty') then SOME thyname else NONE)
-            o these o Symtab.lookup deftab) c
-        of SOME thyname => thyname
-         | _ => (NameSpace.drop_base o alias_get thy o fst o dest_Type o hd o fst o strip_type) ty
       val c' = idf_of_name thy nsp_overl c;
-      val c'' = NameSpace.append thyname (NameSpace.append nsp_overl (NameSpace.base c'));
+      val prefix = case (AList.lookup (eq_typ thy)
+          (Defs.specifications_of (Theory.defs_of thy) c)) ty
+       of SOME thyname => NameSpace.append thyname nsp_overl
+        | NONE => NameSpace.drop_base c';
+      val c'' = NameSpace.append prefix (NameSpace.base c');
       fun mangle (Type (tyco, tys)) =
             (NameSpace.base o alias_get thy) tyco :: Library.flat (List.mapPartial mangle tys) |> SOME
         | mangle _ =
             NONE
     in
       Vartab.empty
-      |> Sign.typ_match thy (ty_decl, ty)
+      |> Sign.typ_match thy (Sign.the_const_type thy c, ty)
       |> map (snd o snd) o Vartab.dest
       |> List.mapPartial mangle
       |> Library.flat
@@ -182,8 +188,14 @@
       |> curry (op ^ o swap) ((implode oo replicate) i "'")
     end;
   fun is_valid _ _ = true;
-  fun maybe_unique _ _ = NONE;
-  fun re_mangle _ dst = error ("no such constant: " ^ quote dst);
+  fun maybe_unique thy (c, ty) = 
+    if is_overloaded thy c
+      then NONE
+      else (SOME o idf_of_name thy nsp_const) c;
+  fun re_mangle thy idf =
+   case name_of_idf thy nsp_const idf
+    of NONE => error ("no such constant: " ^ quote idf)
+     | SOME c => (c, Sign.the_const_type thy c);
 );
 
 structure DatatypeconsNameMangler = NameManglerFun (
@@ -212,13 +224,12 @@
 );
 
 type auxtab = (deftab * string Symtab.table)
-  * (InstNameMangler.T * ((typ * typ list) Symtab.table * ConstNameMangler.T)
+  * (InstNameMangler.T * (typ list Symtab.table * ConstNameMangler.T)
   * DatatypeconsNameMangler.T);
 type eqextr = theory -> auxtab
   -> string * typ -> (thm list * typ) option;
 type eqextr_default = theory -> auxtab
   -> string * typ -> ((thm list * term option) * typ) option;
-type defgen = theory -> auxtab -> gen_defgen;
 type appgen = theory -> auxtab
   -> (string * typ) * term list -> transact -> iexpr * transact;
 
@@ -395,10 +406,10 @@
   let
     fun get_overloaded (c, ty) =
       case Symtab.lookup overltab1 c
-       of SOME (ty_decl, tys) =>
+       of SOME tys =>
             (case find_first (curry (Sign.typ_instance thy) ty) tys
-             of SOME ty' => ConstNameMangler.get (thy, deftab) overltab2
-                  (c, (ty_decl, ty')) |> SOME
+             of SOME ty' => ConstNameMangler.get thy overltab2
+                  (c, ty') |> SOME
               | _ => NONE)
         | _ => NONE
     fun get_datatypecons (c, ty) =
@@ -422,8 +433,7 @@
         case dest_nsp nsp_overl idf
          of SOME _ =>
               idf
-              |> ConstNameMangler.rev (thy, deftab) overltab2
-              |> apsnd snd
+              |> ConstNameMangler.rev thy overltab2
               |> SOME
           | NONE => NONE
       );
@@ -607,7 +617,7 @@
   in
     trns
     |> debug 4 (fn _ => "generating class " ^ quote cls)
-    |> gen_ensure_def [("class", defgen_class thy tabs)] ("generating class " ^ quote cls) cls'
+    |> ensure_def [("class", defgen_class thy tabs)] ("generating class " ^ quote cls) cls'
     |> pair cls'
   end
 and ensure_def_tyco thy tabs tyco trns =
@@ -638,7 +648,7 @@
   in
     trns
     |> debug 4 (fn _ => "generating type constructor " ^ quote tyco)
-    |> gen_ensure_def [("datatype", defgen_datatype thy tabs)] ("generating type constructor " ^ quote tyco) tyco'
+    |> ensure_def [("datatype", defgen_datatype thy tabs)] ("generating type constructor " ^ quote tyco) tyco'
     |> pair tyco'
   end
 and exprgen_tyvar_sort thy tabs (v, sort) trns =
@@ -715,14 +725,15 @@
               val (arity, memdefs) = ClassPackage.the_inst_sign thy (class, tyco);
               fun gen_suparity supclass trns =
                 trns
-                |> (fold_map o fold_map) (exprgen_classlookup thy tabs)
-                     (ClassPackage.extract_classlookup_inst thy (supclass, tyco) supclass)
+                |> ensure_def_class thy tabs supclass
                 ||>> ensure_def_inst thy tabs (supclass, tyco)
-                |-> (fn (ls, _) => pair (supclass, ls));
+                ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs)
+                      (ClassPackage.extract_classlookup_inst thy (supclass, tyco) supclass)
+                |-> (fn ((supclass, inst), lss) => pair (supclass, (inst, lss)));
               fun gen_membr (m, ty) trns =
                 trns
                 |> mk_fun thy tabs (m, ty)
-                |-> (fn SOME funn => pair (idf_of_name thy nsp_mem m, funn)
+                |-> (fn SOME funn => pair (idf_of_name thy nsp_mem m, (idf_of_name thy nsp_mem m ^ "'", funn))
                       | NONE => error ("could not derive definition for member " ^ quote m));
             in
               trns
@@ -743,7 +754,7 @@
   in
     trns
     |> debug 4 (fn _ => "generating instance " ^ quote cls ^ " / " ^ quote tyco)
-    |> gen_ensure_def [("instance", defgen_inst thy tabs)]
+    |> ensure_def [("instance", defgen_inst thy tabs)]
          ("generating instance " ^ quote cls ^ " / " ^ quote tyco) inst
     |> pair inst
   end
@@ -791,30 +802,28 @@
   in
     trns
     |> debug 4 (fn _ => "generating constant " ^ quote c)
-    |> gen_ensure_def ((single o get_defgen) idf) ("generating constant " ^ quote c) idf
+    |> ensure_def ((single o get_defgen) idf) ("generating constant " ^ quote c) idf
     |> pair idf
   end
 and exprgen_term thy tabs (Const (f, ty)) trns =
       trns
       |> appgen thy tabs ((f, ty), [])
       |-> (fn e => pair e)
-  (* | exprgen_term thy tabs (Var ((v, 0), ty)) trns =
-      trns
-      |> (exprgen_type thy tabs) ty
-      |-> (fn ty => pair (IVarE (v, ty)))
-  | exprgen_term thy tabs (Var ((_, _), _)) trns =
-      error "Var with index greater 0 encountered during code generation" *)
   | exprgen_term thy tabs (Var _) trns =
       error "Var encountered during code generation"
   | exprgen_term thy tabs (Free (v, ty)) trns =
       trns
       |> exprgen_type thy tabs ty
       |-> (fn ty => pair (IVarE (v, ty)))
-  | exprgen_term thy tabs (Abs (v, ty, t)) trns =
-      trns
-      |> exprgen_type thy tabs ty
-      ||>> exprgen_term thy tabs (subst_bound (Free (v, ty), t))
-      |-> (fn (ty, e) => pair (IVarE (v, ty) `|-> e))
+  | exprgen_term thy tabs (Abs (abs as (_, ty, _))) trns =
+      let
+        val (v, t) = Term.variant_abs abs
+      in
+        trns
+        |> exprgen_type thy tabs ty
+        ||>> exprgen_term thy tabs t
+        |-> (fn (ty, e) => pair (IVarE (v, ty) `|-> e))
+      end
   | exprgen_term thy tabs (t as t1 $ t2) trns =
       let
         val (t', ts) = strip_comb t
@@ -925,7 +934,25 @@
     val idf = idf_of_const thy tabs (c, ty);
   in
     trns
-    |> gen_ensure_def [("wfrec", (K o succeed) Undef)] ("generating wfrec") idf
+    |> ensure_def [("wfrec", (K o succeed) Undef)] ("generating wfrec") idf
+    |> exprgen_type thy tabs ty'
+    ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs)
+           (ClassPackage.extract_classlookup thy (c, ty))
+    ||>> exprsgen_type thy tabs [ty_def]
+    ||>> exprgen_term thy tabs tf
+    ||>> exprgen_term thy tabs tx
+    |-> (fn ((((_, ls), [ty]), tf), tx) => pair (IConst ((idf, ty), ls) `$ tf `$ tx))
+  end;
+
+
+fun appgen_wfrec thy tabs ((c, ty), [_, tf, tx]) trns =
+  let
+    val ty_def = (op ---> o apfst tl o strip_type o Sign.the_const_type thy) c;
+    val ty' = (op ---> o apfst tl o strip_type) ty;
+    val idf = idf_of_const thy tabs (c, ty);
+  in
+    trns
+    |> ensure_def [("wfrec", (K o succeed) Undef)] ("generating wfrec") idf
     |> exprgen_type thy tabs ty'
     ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs)
            (ClassPackage.extract_classlookup thy (c, ty))
@@ -992,8 +1019,6 @@
 
 fun mk_tabs thy =
   let
-    fun get_specifications thy c =
-      Defs.specifications_of (Theory.defs_of thy) c;
     fun extract_defs thy =
       let
         fun dest thm =
@@ -1027,56 +1052,39 @@
             (fn (tyco, thyname) => InstNameMangler.declare thy (thyname, (cls, tyco))) clsinsts)
                  (ClassPackage.get_classtab thy)
       |-> (fn _ => I);
-    fun add_monoeq thy deftab (overltab1, overltab2) =
-      let
-        val c = "op =";
-        val ty = Sign.the_const_type thy c;
-        fun inst dtco = 
-          map_atyps (fn _ => Type (dtco,
-            (map (fn (v, sort) => TVar ((v, 0), sort)) o fst o the o get_datatype thy) dtco)) ty
-        val dtcos = fold (insert (op =) o snd) (get_all_datatype_cons thy) [];
-        val tys = map inst dtcos;
-      in
-        (overltab1
-         |> Symtab.update_new (c, (ty, tys)),
-         overltab2
-         |> fold (fn ty' => ConstNameMangler.declare (thy, deftab)
-              (c, (ty, ty')) #> snd) tys)
-      end;
-    (* über *alle*: (map fst o NameSpace.dest_table o Consts.space_of o Sign.consts_of) thy
-       * (c, ty) reicht dann zur zünftigen Deklaration
-       * somit fliegt ein Haufen Grusch raus, deftab bleibt allerdings wegen thyname
-      fun mk_overltabs thy =
+    fun mk_overltabs thy =
       (Symtab.empty, ConstNameMangler.empty)
       |> Symtab.fold
-          (fn c => if (is_none o ClassPackage.lookup_const_class thy) c
-            then case get_specifications thy c
-             of [_] => NONE
-              | tys => fold
-                (fn (overltab1, overltab2) => (
-                    overltab1
-                    |> Symtab.update_new (c, (Sign.the_const_type thy c, tys)),
-                    overltab2
-                    |> fold (fn (ty, (_, thyname)) => ConstNameMangler.declare (thy, deftab)
-                         (c, (Sign.the_const_type thy c, ty)) #> snd) tys))
-                else I
-          ) deftab
-      |> add_monoeq thy deftab;*)
-    fun mk_overltabs thy deftab =
-      (Symtab.empty, ConstNameMangler.empty)
-      |> Symtab.fold
-          (fn (c, [_]) => I
-            | (c, tytab) =>
-                if (is_none o ClassPackage.lookup_const_class thy) c
-                then (fn (overltab1, overltab2) => (
-                    overltab1
-                    |> Symtab.update_new (c, (Sign.the_const_type thy c, map fst tytab)),
-                    overltab2
-                    |> fold (fn (ty, (_, thyname)) => ConstNameMangler.declare (thy, deftab)
-                         (c, (Sign.the_const_type thy c, ty)) #> snd) tytab))
-                else I
-          ) deftab
-      |> add_monoeq thy deftab;
+          (fn (c, _) =>
+            let
+              val deftab = Defs.specifications_of (Theory.defs_of thy) c
+              val is_overl = (is_none o ClassPackage.lookup_const_class thy) c
+               andalso case deftab
+               of [] => false
+                | [(ty, _)] => not (eq_typ thy (ty, Sign.the_const_type thy c))
+                | _ => true;
+            in if is_overl then (fn (overltab1, overltab2) => (
+              overltab1
+              |> Symtab.update_new (c, map fst deftab),
+              overltab2
+              |> fold_map (fn (ty, _) => ConstNameMangler.declare thy (c, ty)) deftab
+              |-> (fn _ => I))) else I
+            end) ((#2 o #constants o Consts.dest o #consts o Sign.rep_sg) thy)
+      |> (fn (overltab1, overltab2) =>
+            let
+              val c = "op =";
+              val ty = Sign.the_const_type thy c;
+              fun inst dtco = 
+                map_atyps (fn _ => Type (dtco,
+                  (map (fn (v, sort) => TVar ((v, 0), sort)) o fst o the o get_datatype thy) dtco)) ty
+              val dtcos = fold (insert (op =) o snd) (get_all_datatype_cons thy) [];
+              val tys = map inst dtcos;
+            in
+              (overltab1
+               |> Symtab.update_new (c, tys),
+               overltab2
+               |> fold (fn ty => ConstNameMangler.declare thy (c, ty) #> snd) tys)
+            end);
     fun mk_dtcontab thy =
       DatatypeconsNameMangler.empty
       |> fold_map
@@ -1095,7 +1103,7 @@
               (ClassPackage.get_classtab thy);
     val deftab = extract_defs thy;
     val insttab = mk_insttab thy;
-    val overltabs = mk_overltabs thy deftab;
+    val overltabs = mk_overltabs thy;
     val dtcontab = mk_dtcontab thy;
     val clsmemtab = mk_clsmemtab thy;
   in ((deftab, clsmemtab), (insttab, overltabs, dtcontab)) end;
@@ -1109,9 +1117,9 @@
   map_codegen_data (fn (modl, gens, target_data, logic_data) =>
     (f modl, gens, target_data, logic_data));
 
-fun expand_module init gen thy =
+fun expand_module init gen arg thy =
   (#modl o CodegenData.get) thy
-  |> start_transact init (gen thy (mk_tabs thy))
+  |> start_transact init (gen thy (mk_tabs thy) arg)
   |-> (fn x:'a => fn modl => (x, map_module (K modl) thy));
 
 fun rename_inconsistent thy =
@@ -1154,9 +1162,18 @@
 fun codegen_incr t thy =
   thy
   |> `(#modl o CodegenData.get)
-  ||>> expand_module NONE (fn thy => fn tabs => exprsgen_term thy tabs [t])
+  ||>> expand_module NONE exprsgen_term [t]
   ||>> `(#modl o CodegenData.get)
-  |-> (fn ((modl_old, _), modl_new) => pair (CodegenThingol.diff_module (modl_new, modl_old)));
+  |-> (fn ((modl_old, [t]), modl_new) => pair (t, CodegenThingol.diff_module (modl_new, modl_old)));
+
+val is_dtcon = has_nsp nsp_dtcon;
+
+fun consts_of_idfs thy =
+  let
+    val tabs = mk_tabs thy;
+  in
+    map (fn idf => (idf, (the o recconst_of_idf thy tabs) idf))
+  end;
 
 fun get_ml_fun_datatype thy resolv =
   let
@@ -1177,21 +1194,13 @@
 fun read_typ thy =
   Sign.read_typ (thy, K NONE);
 
-fun read_const thy (raw_c, raw_ty) =
-  let
-    val c = Sign.intern_const thy raw_c;
-    val _ = if Sign.declared_const thy c
-      then ()
-      else error ("no such constant: " ^ quote c);
-    val ty = case raw_ty
-     of NONE => Sign.the_const_type thy c
-      | SOME raw_ty => read_typ thy raw_ty;
-  in (c, ty) end;
+fun read_const thy =
+  (dest_Const o Sign.read_term thy);
 
 fun read_quote get reader gen raw thy =
   thy
   |> expand_module ((SOME o get) thy)
-       (fn thy => fn tabs => (gen thy tabs o single o reader thy) raw)
+       (fn thy => fn tabs => gen thy tabs o single o reader thy) raw
   |-> (fn [x] => pair x);
 
 fun gen_add_prim prep_name prep_primdef raw_name (target, raw_primdef) thy =
@@ -1337,10 +1346,9 @@
 fun generate_code (SOME raw_consts) thy =
       let
         val consts = map (read_const thy) raw_consts;
-        fun generate thy tabs = fold_map (ensure_def_const thy tabs) consts
       in
         thy
-        |> expand_module NONE generate
+        |> expand_module NONE (fold_map oo ensure_def_const) consts
         |-> (fn cs => pair (SOME cs))
       end
   | generate_code NONE thy =
@@ -1381,7 +1389,7 @@
 
 val generateP =
   OuterSyntax.command generateK "generate executable code for constants" K.thy_decl (
-    Scan.repeat1 (P.name -- Scan.option (P.$$$ "::" |-- P.typ))
+    Scan.repeat1 P.term
     >> (fn raw_consts =>
           Toplevel.theory (generate_code (SOME raw_consts) #> snd))
   );
@@ -1389,7 +1397,7 @@
 val serializeP =
   OuterSyntax.command serializeK "serialize executable code for constants" K.thy_decl (
     P.name
-    -- Scan.option (Scan.repeat1 (P.name -- Scan.option (P.$$$ "::" |-- P.typ)))
+    -- Scan.option (Scan.repeat1 P.term)
     #-> (fn (target, raw_consts) =>
           P.$$$ "("
           |-- get_serializer target
@@ -1423,7 +1431,7 @@
 
 val primconstP =
   OuterSyntax.command primconstK "define target-lanugage specific constant" K.thy_decl (
-    (P.xname -- Scan.option (P.$$$ "::" |-- P.typ))
+    P.term
     -- Scan.repeat1 (P.name -- P.text)
       >> (fn (raw_const, primdefs) =>
             (Toplevel.theory oo fold) (add_prim_const raw_const) primdefs)
@@ -1456,7 +1464,7 @@
 val syntax_constP =
   OuterSyntax.command syntax_constK "define code syntax for constant" K.thy_decl (
     Scan.repeat1 (
-      (P.xname -- Scan.option (P.$$$ "::" |-- P.typ))
+      P.term
       #-> (fn raw_const => Scan.repeat1 (
              P.name -- parse_syntax_const raw_const
           ))