improved implementation
authorhaftmann
Fri, 18 Jan 2008 08:30:12 +0100
changeset 25924 f974a1c64348
parent 25923 5fe4b543512e
child 25925 3dc4acca4388
improved implementation
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Thu Jan 17 21:56:33 2008 +0100
+++ b/src/Tools/nbe.ML	Fri Jan 18 08:30:12 2008 +0100
@@ -12,17 +12,19 @@
 
   datatype Univ =
       Const of string * Univ list            (*named (uninterpreted) constants*)
-    | Free of string * Univ list
-    | DFree of string                        (*free (uninterpreted) dictionary parameters*)
+    | Free of string * Univ list             (*free (uninterpreted) variables*)
+    | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
     | BVar of int * Univ list
     | Abs of (int * (Univ list -> Univ)) * Univ list;
-  val free: string -> Univ                   (*free (uninterpreted) variables*)
-  val app: Univ -> Univ -> Univ              (*explicit application*)
+  val apps: Univ -> Univ list -> Univ        (*explicit applications*)
   val abs: int -> (Univ list -> Univ) -> Univ
-                                             (*abstractions as closures*)
+                                            (*abstractions as closures*)
 
   val univs_ref: (unit -> Univ list -> Univ list) option ref
+  val norm_invoke: theory -> CodeThingol.code -> term
+    -> CodeThingol.typscheme * CodeThingol.iterm -> string list -> thm
   val trace: bool ref
+
   val setup: theory -> theory
 end;
 
@@ -58,20 +60,23 @@
 datatype Univ =
     Const of string * Univ list        (*named (uninterpreted) constants*)
   | Free of string * Univ list         (*free variables*)
-  | DFree of string                    (*free (uninterpreted) dictionary parameters*)
+  | DFree of string * int              (*free (uninterpreted) dictionary parameters*)
   | BVar of int * Univ list            (*bound named variables*)
   | Abs of (int * (Univ list -> Univ)) * Univ list
                                       (*abstractions as closures*);
 
 (* constructor functions *)
 
-fun free v = Free (v, []);
 fun abs n f = Abs ((n, f), []);
-fun app (Abs ((1, f), xs)) x = f (x :: xs)
-  | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
-  | app (Const (name, args)) x = Const (name, x :: args)
-  | app (Free (name, args)) x = Free (name, x :: args)
-  | app (BVar (name, args)) x = BVar (name, x :: args);
+fun apps (Abs ((n, f), xs)) ys = let val k = n - length ys in
+      if k = 0 then f (ys @ xs)
+      else if k < 0 then
+        let val (zs, ws) = chop (~ k) ys
+        in apps (f (ws @ xs)) zs end
+      else Abs ((k, f), ys @ xs) end (*note: reverse convention also for apps!*)
+  | apps (Const (name, xs)) ys = Const (name, ys @ xs)
+  | apps (Free (name, xs)) ys = Free (name, ys @ xs)
+  | apps (BVar (name, xs)) ys = BVar (name, ys @ xs);
 
 (* universe graph *)
 
@@ -114,29 +119,27 @@
 local
   val prefix =          "Nbe.";
   val name_const =      prefix ^ "Const";
-  val name_free =       prefix ^ "free";
-  val name_dfree =      prefix ^ "DFree";
   val name_abs =        prefix ^ "abs";
-  val name_app =        prefix ^ "app";
-  val name_lookup_fun = prefix ^ "lookup_fun";
+  val name_apps =       prefix ^ "apps";
 in
 
-fun nbe_const c ts =
-  name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
-fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
-fun nbe_free v = name_free `$` ML_Syntax.print_string v;
-fun nbe_dfree v = name_dfree `$` ML_Syntax.print_string v;
+fun nbe_fun' c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
+val nbe_fun = nbe_fun'; (*FIXME!*)
 fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
 fun nbe_bound v = "v_" ^ v;
+val nbe_value = "";
 
-fun nbe_apps e es =
-  Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
+(*note: these three are the "turning spots" where proper argument order is established!*)
+fun nbe_apps t [] = t
+  | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)];
+fun nbe_apps_local c ts = nbe_fun c `$` ml_list (rev ts);
+fun nbe_apps_constr c ts =
+  name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list (rev ts) ^ ")");
+
 
 fun nbe_abss 0 f = f `$` ml_list []
   | nbe_abss n f = name_abs `$$` [string_of_int n, f];
 
-val nbe_value = "value";
-
 end;
 
 open BasicCodeThingol;
@@ -154,26 +157,31 @@
 
 (* code generation *)
 
-fun assemble_idict (DictConst (inst, dss)) =
-      nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss)
-  | assemble_idict (DictVar (supers, (v, (n, _)))) =
-      fold_rev (fn super => nbe_apps (nbe_fun super) o single) supers (nbe_dict v n);
+datatype const_kind = Local of int | Global | Constr;
 
-fun assemble_iterm is_fun num_args =
+fun assemble_constapp kind c dss ts = 
+      let
+        val ts' = (maps o map) (assemble_idict kind) dss @ ts;
+      in case kind c
+       of Local n => if n <= length ts'
+            then let val (ts1, ts2) = chop n ts'
+            in nbe_apps (nbe_apps_local c ts1) ts2
+            end else nbe_apps (nbe_abss n (nbe_fun c)) ts'
+        | Global => nbe_apps (nbe_fun c) ts'
+        | Constr => nbe_apps_constr c ts'
+      end
+and assemble_idict kind (DictConst (inst, dss)) =
+      assemble_constapp kind inst dss []
+  | assemble_idict kind (DictVar (supers, (v, (n, _)))) =
+      fold_rev (fn super => assemble_constapp kind super [] o single) supers (nbe_dict v n);
+
+fun assemble_iterm kind =
   let
     fun of_iterm t =
       let
         val (t', ts) = CodeThingol.unfold_app t
-      in of_iapp t' (fold (cons o of_iterm) ts []) end
-    and of_iconst c ts = case num_args c
-     of SOME n => if n <= length ts
-          then let val (args2, args1) = chop (length ts - n) ts
-          in nbe_apps (nbe_fun c `$` ml_list args1) args2
-          end else nbe_const c ts
-      | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
-          else nbe_const c ts
-    and of_iapp (IConst (c, (dss, _))) ts = of_iconst c
-          (ts @ rev ((maps o map) assemble_idict dss))
+      in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
+    and of_iapp (IConst (c, (dss, _))) ts = assemble_constapp kind c dss ts
       | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
       | of_iapp ((v, _) `|-> t) ts =
           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
@@ -182,17 +190,16 @@
             @ [("_", of_iterm t0)])) ts
   in of_iterm end;
 
-fun assemble_fun gr num_args (c, (vs, eqns)) =
+fun assemble_eqns kind (c, (vs, eqns)) =
   let
-    val assemble_arg = assemble_iterm (K false) (K NONE);
-    val assemble_rhs = assemble_iterm (is_some o Graph.get_node gr) num_args;
-    val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
-      |> rev;
+    val dict_args = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
+    val assemble_arg = assemble_iterm (K Constr);
+    val assemble_rhs = assemble_iterm kind;
     fun assemble_eqn (args, rhs) =
-      ([ml_list (map assemble_arg (rev args) @ dict_params)], assemble_rhs rhs);
-    val default_params = map nbe_bound (Name.invent_list [] "a" ((the o num_args) c));
-    val default_eqn = ([ml_list default_params], nbe_const c default_params);
-  in map assemble_eqn eqns @ [default_eqn] end;
+      ([ml_list (rev (dict_args @ map assemble_arg args))], assemble_rhs rhs);
+    val default_args = dict_args @ map nbe_bound (Name.invent_list [] "a" ((length o fst o hd) eqns));
+    val default_eqn = ([ml_list (rev default_args)], nbe_apps_constr c default_args);
+  in (nbe_fun' c, map assemble_eqn eqns @ [default_eqn]) end;
 
 fun assemble_eqnss gr deps [] = ([], ("", []))
   | assemble_eqnss gr deps eqnss =
@@ -200,10 +207,13 @@
         val cs = map fst eqnss;
         val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
           length (maps snd vs) + length args) eqnss;
+        fun kind c = case AList.lookup (op =) num_args c
+         of SOME n => Local n
+          | NONE => if (is_some o Option.join o try (Graph.get_node gr)) c
+              then Global else Constr;
         val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
-        val bind_deps = ml_list (map nbe_fun deps');
-        val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
-          (assemble_fun gr (AList.lookup (op =) num_args)) eqnss);
+        val bind_deps = ml_list (map nbe_fun' deps');
+        val bind_locals = ml_fundefs (map (assemble_eqns kind) eqnss);
         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
         val arg_deps = map (the o Graph.get_node gr) deps';
       in (cs, (ml_abs bind_deps (ml_Let [bind_locals] result), arg_deps)) end;
@@ -263,25 +273,17 @@
           Graph.imm_succs code name)) names);
   in fold_rev add_stmts (Graph.strong_conn code) end;
 
-fun assemble_eval gr deps ((vs, ty), t) =
-  let
-    val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
-    val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
-    val bind_deps = ml_list (map nbe_fun deps');
-    val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
-      |> rev;
-    val bind_value = ml_fundefs [(nbe_value,
-      [([ml_list (map nbe_bound frees @ dict_params)],
-        assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])];
-    val result = ml_list [nbe_value `$` ml_list
-      (map nbe_free frees @ map nbe_dfree dict_params)];
-    val arg_deps = map (the o Graph.get_node gr) deps';
-  in (ml_abs bind_deps (ml_Let [bind_value] result), arg_deps) end;
-
-fun eval_term gr deps t' =
-  let
-    val (s, args) = assemble_eval gr deps t';
-  in the_single (compile s args) end;
+fun eval_term gr deps ((vs, ty), t) =
+  let 
+    val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t []
+    val frees' = map (fn v => Free (v, [])) frees;
+    val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
+  in
+    (nbe_value, (vs, [(map IVar frees, t)]))
+    |> singleton (compile_eqnss gr deps)
+    |> snd
+    |> (fn t => apps t (rev (dict_frees @ frees')))
+  end;
 
 
 (** evaluation **)
@@ -315,7 +317,7 @@
           of_apps bounds (Bound (bounds - name - 1), ts) typidx
       | of_univ bounds (t as Abs _) typidx =
           typidx
-          |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
+          |> of_univ (bounds + 1) (apps t [BVar (bounds, [])])
           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   in of_univ 0 t 0 |> fst end;