now employing dictionaries
authorhaftmann
Fri, 19 Oct 2007 16:20:27 +0200
changeset 25101 cae0f68b693b
parent 25100 fe9632d914c7
child 25102 db3e412c4cb1
now employing dictionaries
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Fri Oct 19 16:18:00 2007 +0200
+++ b/src/Tools/nbe.ML	Fri Oct 19 16:20:27 2007 +0200
@@ -7,22 +7,23 @@
 
 signature NBE =
 sig
+  val norm_conv: cterm -> thm
+  val norm_term: theory -> term -> term
+
   datatype Univ = 
       Const of string * Univ list            (*named (uninterpreted) constants*)
     | Free of string * Univ list
+    | DFree of string                        (*free (uninterpreted) dictionary parameters*)
     | BVar of int * Univ list
     | Abs of (int * (Univ list -> Univ)) * Univ list;
-  val free: string -> Univ list -> Univ       (*free (uninterpreted) variables*)
-  val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
-                                            (*abstractions as functions*)
+  val free: string -> Univ                   (*free (uninterpreted) variables*)
   val app: Univ -> Univ -> Univ              (*explicit application*)
+  val abs: int -> (Univ list -> Univ) -> Univ
+                                             (*abstractions as closures*)
 
   val univs_ref: (unit -> Univ list) ref 
   val lookup_fun: string -> Univ
 
-  val norm_conv: cterm -> thm
-  val norm_term: theory -> term -> term
-
   val trace: bool ref
   val setup: theory -> theory
 end;
@@ -59,41 +60,25 @@
 datatype Univ = 
     Const of string * Univ list        (*named (uninterpreted) constants*)
   | Free of string * Univ list         (*free variables*)
+  | DFree of string                    (*free (uninterpreted) dictionary parameters*)
   | BVar of int * Univ list            (*bound named variables*)
   | Abs of (int * (Univ list -> Univ)) * Univ list
                                       (*abstractions as closures*);
 
 (* constructor functions *)
 
-val free = curry Free;
-fun abs n f ts = Abs ((n, f), ts);
+fun free v = Free (v, []);
+fun abs n f = Abs ((n, f), []);
 fun app (Abs ((1, f), xs)) x = f (x :: xs)
   | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
   | app (Const (name, args)) x = Const (name, x :: args)
   | app (Free (name, args)) x = Free (name, x :: args)
   | app (BVar (name, args)) x = BVar (name, x :: args);
 
-(* global functions store *)
+(* universe graph *)
 
-structure Nbe_Functions = CodeDataFun
-(
-  type T = Univ Graph.T;
-  val empty = Graph.empty;
-  fun merge _ = Graph.merge (K true);
-  fun purge _ NONE _ = Graph.empty
-    | purge NONE _ _ = Graph.empty
-    | purge (SOME thy) (SOME cs) gr = Graph.empty
-        (*let
-          val cs_exisiting =
-            map_filter (CodeName.const_rev thy) (Graph.keys gr);
-          val dels = (Graph.all_preds gr
-              o map (CodeName.const thy)
-              o filter (member (op =) cs_exisiting)
-            ) cs;
-        in Graph.del_nodes dels gr end*);
-);
-
-fun defined gr = can (Graph.get_node gr);
+type univ_gr = Univ option Graph.T;
+val compiled : univ_gr -> string -> bool = can o Graph.get_node;
 
 (* sandbox communication *)
 
@@ -101,30 +86,28 @@
 
 local
 
-val gr_ref = ref NONE : Nbe_Functions.T option ref;
+val gr_ref = ref NONE : univ_gr option ref;
 
-fun compile tab raw_s = NAMED_CRITICAL "nbe" (fn () =>
+fun compile gr raw_s = NAMED_CRITICAL "nbe" (fn () =>
   let
     val _ = univs_ref := (fn () => []);
     val s = "Nbe.univs_ref := " ^ raw_s;
     val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
-    val _ = gr_ref := SOME tab;
+    val _ = gr_ref := SOME gr;
     val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
       Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
       (!trace) s;
     val _ = gr_ref := NONE;
   in !univs_ref end);
+
 in
 
-fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () =>
- case ! gr_ref
+fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => case ! gr_ref
  of NONE => error "compile_univs"
-  | SOME gr => Graph.get_node gr s);
+  | SOME gr => the (Graph.get_node gr s));
 
-fun compile_univs tab ([], _) = []
-  | compile_univs tab (cs, raw_s) = case compile tab raw_s ()
-     of [] => error "compile_univs"
-      | univs => cs ~~ univs;
+fun compile_univs gr ([], _) = []
+  | compile_univs gr (cs, raw_s) = cs ~~ compile gr raw_s ();
 
 end; (*local*)
 
@@ -135,7 +118,8 @@
 
 infix 9 `$` `$$`;
 fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
-fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
+fun e `$$` [] = e
+  | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
 fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";
 
 fun ml_Val v s = "val " ^ v ^ " = " ^ s;
@@ -167,21 +151,25 @@
   val prefix =          "Nbe.";
   val name_const =      prefix ^ "Const";
   val name_free =       prefix ^ "free";
+  val name_dfree =      prefix ^ "DFree";
   val name_abs =        prefix ^ "abs";
   val name_app =        prefix ^ "app";
   val name_lookup_fun = prefix ^ "lookup_fun";
 in
 
-fun nbe_const c ts = name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
+fun nbe_const c ts =
+  name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
 fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
-fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];
+fun nbe_free v = name_free `$` ML_Syntax.print_string v;
+fun nbe_dfree v = name_dfree `$` ML_Syntax.print_string v;
+fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
 fun nbe_bound v = "v_" ^ v;
 
 fun nbe_apps e es =
   Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
 
 fun nbe_abss 0 f = f `$` ml_list []
-  | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];
+  | nbe_abss n f = name_abs `$$` [string_of_int n, f];
 
 fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
 
@@ -193,7 +181,12 @@
 
 (* greetings to Tarski *)
 
-fun assemble_iterm thy is_fun num_args =
+fun assemble_idict (DictConst (inst, dss)) =
+      nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss)
+  | assemble_idict (DictVar (supers, (v, (n, _)))) =
+      fold (fn super => nbe_apps (nbe_fun super) o single) supers (nbe_dict v n);
+
+fun assemble_iterm is_fun num_args =
   let
     fun of_iterm t =
       let
@@ -206,7 +199,8 @@
           end else nbe_const c ts
       | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
           else nbe_const c ts
-    and of_iapp (IConst (c, (dss, _))) ts = of_iconst c ts
+    and of_iapp (IConst (c, (dss, _))) ts = of_iconst c
+          (ts @ rev ((maps o map) assemble_idict dss))
       | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
       | of_iapp ((v, _) `|-> t) ts =
           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
@@ -215,112 +209,123 @@
             @ [("_", of_iterm t0)])) ts
   in of_iterm end;
 
-fun assemble_fun thy is_fun num_args (c, eqns) =
+fun assemble_fun gr num_args (c, (vs, eqns)) =
   let
-    val assemble_arg = assemble_iterm thy (K false) (K NONE);
-    val assemble_rhs = assemble_iterm thy is_fun num_args;
+    val assemble_arg = assemble_iterm (K false) (K NONE);
+    val assemble_rhs = assemble_iterm (is_some o Graph.get_node gr) num_args;
+    val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
+      |> rev;
     fun assemble_eqn (args, rhs) =
-      ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
-    val default_params = map nbe_bound
-      (Name.invent_list [] "a" ((the o num_args) c));
+      ([ml_list (map assemble_arg (rev args) @ dict_params)], assemble_rhs rhs);
+    val default_params = map nbe_bound (Name.invent_list [] "a" ((the o num_args) c));
     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   in map assemble_eqn eqns @ [default_eqn] end;
 
-fun assemble_eqnss thy is_fun ([], deps) = ([], "")
-  | assemble_eqnss thy is_fun (eqnss, deps) =
+fun assemble_eqnss gr ([], deps) = ([], "")
+  | assemble_eqnss gr (eqnss, deps) =
       let
         val cs = map fst eqnss;
-        val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
-        val funs = fold (fold (CodeThingol.fold_constnames
-          (insert (op =))) o map snd o snd) eqnss [];
-        val bind_funs = map nbe_lookup (filter is_fun funs);
+        val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
+          length (maps snd vs) + length args) eqnss;
+        val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
         val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
-          (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
+          (assemble_fun gr (AList.lookup (op =) num_args)) eqnss);
         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
           |> ml_delay;
-      in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
+      in (cs, ml_Let (bind_deps @ [bind_locals]) result) end;
 
-fun assemble_eval thy is_fun (((vs, ty), t), deps) =
-  let
-    val funs = CodeThingol.fold_constnames (insert (op =)) t [];
-    val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
-    val bind_funs = map nbe_lookup (filter is_fun funs);
-    val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
-      assemble_iterm thy is_fun (K NONE) t)])];
-    val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)]
-      |> ml_delay;
-  in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
+fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
+      []
+  | eqns_of_stmt (const, CodeThingol.Fun ((vs, _), eqns)) =
+      [(const, (vs, map fst eqns))]
+  | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
+      []
+  | eqns_of_stmt (_, CodeThingol.Datatype _) =
+      []
+  | eqns_of_stmt (class, CodeThingol.Class (v, (superclasses, classops))) =
+      let
+        val names = map snd superclasses @ map fst classops;
+        val params = Name.invent_list [] "d" (length names);
+        fun mk (k, name) =
+          (name, ([(v, [])],
+            [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))]));
+      in map_index mk names end
+  | eqns_of_stmt (_, CodeThingol.Classrel _) =
+      []
+  | eqns_of_stmt (_, CodeThingol.Classparam _) =
+      []
+  | eqns_of_stmt (inst, CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops))) =
+      [(inst, (arities, [([], IConst (class, ([], [])) `$$
+        map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts
+        @ map (IConst o snd o fst) instops)]))];
 
-fun eqns_of_stmt ((_, CodeThingol.Fun (_, [])), _) =
-      NONE
-  | eqns_of_stmt ((name, CodeThingol.Fun (_, eqns)), deps) =
-      SOME ((name, map fst eqns), deps)
-  | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) =
-      NONE
-  | eqns_of_stmt ((_, CodeThingol.Datatype _), _) =
-      NONE
-  | eqns_of_stmt ((_, CodeThingol.Class _), _) =
-      NONE
-  | eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
-      NONE
-  | eqns_of_stmt ((_, CodeThingol.Classparam _), _) =
-      NONE
-  | eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
-      NONE;
+fun compile_stmts stmts_deps =
+  let
+    val names = map (fst o fst) stmts_deps;
+    val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
+    val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
+    val compiled_deps = names_deps |> maps snd |> distinct (op =) |> subtract (op =) names;
+    fun compile gr = (eqnss, compiled_deps) |> assemble_eqnss gr |> compile_univs gr |> rpair gr;
+  in
+    fold (fn name => Graph.new_node (name, NONE)) names
+    #> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
+    #> compile
+    #-> fold (fn (name, univ) => Graph.map_node name (K (SOME univ)))
+  end;
 
-fun compile_stmts thy is_fun =
-  map_filter eqns_of_stmt
-  #> split_list
-  #> assemble_eqnss thy is_fun
-  #> compile_univs (Nbe_Functions.get thy);
+fun ensure_stmts code =
+  let
+    fun add_stmts names gr = if exists (compiled gr) names then gr else gr
+      |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
+          Graph.imm_succs code name)) names);
+  in fold_rev add_stmts (Graph.strong_conn code) end;
 
-fun eval_term thy is_fun =
-  assemble_eval thy is_fun
-  #> compile_univs (Nbe_Functions.get thy)
+fun assemble_eval gr (((vs, ty), t), deps) =
+  let
+    val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
+    val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
+    val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
+      |> rev;
+    val bind_value = ml_fundefs [(nbe_value,
+      [([ml_list (map nbe_bound frees @ dict_params)],
+        assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])];
+    val result = ml_list [nbe_value `$` ml_list
+      (map nbe_free frees @ map nbe_dfree dict_params)]
+      |> ml_delay;
+  in ([nbe_value], ml_Let (bind_deps @ [bind_value]) result) end;
+
+fun eval_term gr =
+  assemble_eval gr
+  #> compile_univs gr
   #> the_single
   #> snd;
 
 
-(** compilation and evaluation **)
-
-(* ensure global functions *)
-
-fun ensure_funs thy code =
-  let
-    fun add_dep (name, dep) gr =
-      if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep
-      then Graph.add_edge (name, dep) gr else gr;
-    fun compile' stmts gr =
-      let
-        val compiled = compile_stmts thy (defined gr) stmts;
-        val names = map (fst o fst) stmts;
-        val deps = maps snd stmts;
-      in
-        Nbe_Functions.change thy (fold Graph.new_node compiled
-          #> fold (fn name => fold (curry add_dep name) deps) names)
-      end;
-    val nbe_gr = Nbe_Functions.get thy;
-    val stmtss = rev (Graph.strong_conn code)
-      |> (map o map_filter) (fn name => if defined nbe_gr name
-           then NONE
-           else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
-      |> filter_out null
-  in fold compile' stmtss nbe_gr end;
+(** evaluation **)
 
 (* reification *)
 
 fun term_of_univ thy t =
   let
+    fun take_until f [] = []
+      | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
+    fun is_dict (Const (c, _)) =
+          (is_some o CodeName.class_rev thy) c
+          orelse (is_some o CodeName.classrel_rev thy) c
+          orelse (is_some o CodeName.instance_rev thy) c
+      | is_dict (DFree _) = true
+      | is_dict _ = false;
     fun of_apps bounds (t, ts) =
       fold_map (of_univ bounds) ts
       #>> (fn ts' => list_comb (t, rev ts'))
     and of_univ bounds (Const (name, ts)) typidx =
           let
+            val ts' = take_until is_dict ts;
             val SOME c = CodeName.const_rev thy name;
             val T = Code.default_typ thy c;
             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
             val typidx' = typidx + maxidx_of_typ T' + 1;
-          in of_apps bounds (Term.Const (c, T'), ts) typidx' end
+          in of_apps bounds (Term.Const (c, T'), ts') typidx' end
       | of_univ bounds (Free (name, ts)) typidx =
           of_apps bounds (Term.Free (name, dummyT), ts) typidx
       | of_univ bounds (BVar (name, ts)) typidx =
@@ -331,6 +336,33 @@
           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   in of_univ 0 t 0 |> fst end;
 
+(* function store *)
+
+structure Nbe_Functions = CodeDataFun
+(
+  type T = univ_gr;
+  val empty = Graph.empty;
+  fun merge _ = Graph.merge (K true);
+  fun purge _ NONE _ = Graph.empty
+    | purge NONE _ _ = Graph.empty
+    | purge (SOME thy) (SOME cs) gr =
+        let
+          val cs_exisiting =
+            map_filter (CodeName.const_rev thy) (Graph.keys gr);
+          val dels = (Graph.all_preds gr
+              o map (CodeName.const thy)
+              o filter (member (op =) cs_exisiting)
+            ) cs;
+        in Graph.del_nodes dels gr end;
+);
+
+(* compilation, evaluation and reification *)
+
+fun compile_eval thy code vs_ty_t deps =
+  (vs_ty_t, deps)
+  |> eval_term (Nbe_Functions.change thy (ensure_stmts code))
+  |> term_of_univ thy;
+
 (* evaluation with type reconstruction *)
 
 fun eval thy code t vs_ty_t deps =
@@ -349,16 +381,14 @@
       error ("Illegal schematic type variables in normalized term: "
         ^ setmp show_types true (Sign.string_of_term thy) t);
   in
-    (vs_ty_t, deps)
-    |> eval_term thy (defined (ensure_funs thy code))
-    |> term_of_univ thy
+    compile_eval thy code vs_ty_t deps
     |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
     |> anno_vars
     |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
     |> constrain
     |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
+    |> tracing (fn t => "---\n")
     |> check_tvars
-    |> tracing (fn _ => "---\n")
   end;
 
 (* evaluation oracle *)