--- a/src/Tools/nbe.ML Fri Oct 19 16:18:00 2007 +0200
+++ b/src/Tools/nbe.ML Fri Oct 19 16:20:27 2007 +0200
@@ -7,22 +7,23 @@
signature NBE =
sig
+ val norm_conv: cterm -> thm
+ val norm_term: theory -> term -> term
+
datatype Univ =
Const of string * Univ list (*named (uninterpreted) constants*)
| Free of string * Univ list
+ | DFree of string (*free (uninterpreted) dictionary parameters*)
| BVar of int * Univ list
| Abs of (int * (Univ list -> Univ)) * Univ list;
- val free: string -> Univ list -> Univ (*free (uninterpreted) variables*)
- val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
- (*abstractions as functions*)
+ val free: string -> Univ (*free (uninterpreted) variables*)
val app: Univ -> Univ -> Univ (*explicit application*)
+ val abs: int -> (Univ list -> Univ) -> Univ
+ (*abstractions as closures*)
val univs_ref: (unit -> Univ list) ref
val lookup_fun: string -> Univ
- val norm_conv: cterm -> thm
- val norm_term: theory -> term -> term
-
val trace: bool ref
val setup: theory -> theory
end;
@@ -59,41 +60,25 @@
datatype Univ =
Const of string * Univ list (*named (uninterpreted) constants*)
| Free of string * Univ list (*free variables*)
+ | DFree of string (*free (uninterpreted) dictionary parameters*)
| BVar of int * Univ list (*bound named variables*)
| Abs of (int * (Univ list -> Univ)) * Univ list
(*abstractions as closures*);
(* constructor functions *)
-val free = curry Free;
-fun abs n f ts = Abs ((n, f), ts);
+fun free v = Free (v, []);
+fun abs n f = Abs ((n, f), []);
fun app (Abs ((1, f), xs)) x = f (x :: xs)
| app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
| app (Const (name, args)) x = Const (name, x :: args)
| app (Free (name, args)) x = Free (name, x :: args)
| app (BVar (name, args)) x = BVar (name, x :: args);
-(* global functions store *)
+(* universe graph *)
-structure Nbe_Functions = CodeDataFun
-(
- type T = Univ Graph.T;
- val empty = Graph.empty;
- fun merge _ = Graph.merge (K true);
- fun purge _ NONE _ = Graph.empty
- | purge NONE _ _ = Graph.empty
- | purge (SOME thy) (SOME cs) gr = Graph.empty
- (*let
- val cs_exisiting =
- map_filter (CodeName.const_rev thy) (Graph.keys gr);
- val dels = (Graph.all_preds gr
- o map (CodeName.const thy)
- o filter (member (op =) cs_exisiting)
- ) cs;
- in Graph.del_nodes dels gr end*);
-);
-
-fun defined gr = can (Graph.get_node gr);
+type univ_gr = Univ option Graph.T;
+val compiled : univ_gr -> string -> bool = can o Graph.get_node;
(* sandbox communication *)
@@ -101,30 +86,28 @@
local
-val gr_ref = ref NONE : Nbe_Functions.T option ref;
+val gr_ref = ref NONE : univ_gr option ref;
-fun compile tab raw_s = NAMED_CRITICAL "nbe" (fn () =>
+fun compile gr raw_s = NAMED_CRITICAL "nbe" (fn () =>
let
val _ = univs_ref := (fn () => []);
val s = "Nbe.univs_ref := " ^ raw_s;
val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
- val _ = gr_ref := SOME tab;
+ val _ = gr_ref := SOME gr;
val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
(!trace) s;
val _ = gr_ref := NONE;
in !univs_ref end);
+
in
-fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () =>
- case ! gr_ref
+fun lookup_fun s = NAMED_CRITICAL "nbe" (fn () => case ! gr_ref
of NONE => error "compile_univs"
- | SOME gr => Graph.get_node gr s);
+ | SOME gr => the (Graph.get_node gr s));
-fun compile_univs tab ([], _) = []
- | compile_univs tab (cs, raw_s) = case compile tab raw_s ()
- of [] => error "compile_univs"
- | univs => cs ~~ univs;
+fun compile_univs gr ([], _) = []
+ | compile_univs gr (cs, raw_s) = cs ~~ compile gr raw_s ();
end; (*local*)
@@ -135,7 +118,8 @@
infix 9 `$` `$$`;
fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
-fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
+fun e `$$` [] = e
+ | e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";
fun ml_Val v s = "val " ^ v ^ " = " ^ s;
@@ -167,21 +151,25 @@
val prefix = "Nbe.";
val name_const = prefix ^ "Const";
val name_free = prefix ^ "free";
+ val name_dfree = prefix ^ "DFree";
val name_abs = prefix ^ "abs";
val name_app = prefix ^ "app";
val name_lookup_fun = prefix ^ "lookup_fun";
in
-fun nbe_const c ts = name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
+fun nbe_const c ts =
+ name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
-fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];
+fun nbe_free v = name_free `$` ML_Syntax.print_string v;
+fun nbe_dfree v = name_dfree `$` ML_Syntax.print_string v;
+fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
fun nbe_bound v = "v_" ^ v;
fun nbe_apps e es =
Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
fun nbe_abss 0 f = f `$` ml_list []
- | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];
+ | nbe_abss n f = name_abs `$$` [string_of_int n, f];
fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
@@ -193,7 +181,12 @@
(* greetings to Tarski *)
-fun assemble_iterm thy is_fun num_args =
+fun assemble_idict (DictConst (inst, dss)) =
+ nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss)
+ | assemble_idict (DictVar (supers, (v, (n, _)))) =
+ fold (fn super => nbe_apps (nbe_fun super) o single) supers (nbe_dict v n);
+
+fun assemble_iterm is_fun num_args =
let
fun of_iterm t =
let
@@ -206,7 +199,8 @@
end else nbe_const c ts
| NONE => if is_fun c then nbe_apps (nbe_fun c) ts
else nbe_const c ts
- and of_iapp (IConst (c, (dss, _))) ts = of_iconst c ts
+ and of_iapp (IConst (c, (dss, _))) ts = of_iconst c
+ (ts @ rev ((maps o map) assemble_idict dss))
| of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
| of_iapp ((v, _) `|-> t) ts =
nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
@@ -215,112 +209,123 @@
@ [("_", of_iterm t0)])) ts
in of_iterm end;
-fun assemble_fun thy is_fun num_args (c, eqns) =
+fun assemble_fun gr num_args (c, (vs, eqns)) =
let
- val assemble_arg = assemble_iterm thy (K false) (K NONE);
- val assemble_rhs = assemble_iterm thy is_fun num_args;
+ val assemble_arg = assemble_iterm (K false) (K NONE);
+ val assemble_rhs = assemble_iterm (is_some o Graph.get_node gr) num_args;
+ val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
+ |> rev;
fun assemble_eqn (args, rhs) =
- ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
- val default_params = map nbe_bound
- (Name.invent_list [] "a" ((the o num_args) c));
+ ([ml_list (map assemble_arg (rev args) @ dict_params)], assemble_rhs rhs);
+ val default_params = map nbe_bound (Name.invent_list [] "a" ((the o num_args) c));
val default_eqn = ([ml_list default_params], nbe_const c default_params);
in map assemble_eqn eqns @ [default_eqn] end;
-fun assemble_eqnss thy is_fun ([], deps) = ([], "")
- | assemble_eqnss thy is_fun (eqnss, deps) =
+fun assemble_eqnss gr ([], deps) = ([], "")
+ | assemble_eqnss gr (eqnss, deps) =
let
val cs = map fst eqnss;
- val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
- val funs = fold (fold (CodeThingol.fold_constnames
- (insert (op =))) o map snd o snd) eqnss [];
- val bind_funs = map nbe_lookup (filter is_fun funs);
+ val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
+ length (maps snd vs) + length args) eqnss;
+ val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
- (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
+ (assemble_fun gr (AList.lookup (op =) num_args)) eqnss);
val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
|> ml_delay;
- in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
+ in (cs, ml_Let (bind_deps @ [bind_locals]) result) end;
-fun assemble_eval thy is_fun (((vs, ty), t), deps) =
- let
- val funs = CodeThingol.fold_constnames (insert (op =)) t [];
- val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
- val bind_funs = map nbe_lookup (filter is_fun funs);
- val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
- assemble_iterm thy is_fun (K NONE) t)])];
- val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)]
- |> ml_delay;
- in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
+fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
+ []
+ | eqns_of_stmt (const, CodeThingol.Fun ((vs, _), eqns)) =
+ [(const, (vs, map fst eqns))]
+ | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
+ []
+ | eqns_of_stmt (_, CodeThingol.Datatype _) =
+ []
+ | eqns_of_stmt (class, CodeThingol.Class (v, (superclasses, classops))) =
+ let
+ val names = map snd superclasses @ map fst classops;
+ val params = Name.invent_list [] "d" (length names);
+ fun mk (k, name) =
+ (name, ([(v, [])],
+ [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))]));
+ in map_index mk names end
+ | eqns_of_stmt (_, CodeThingol.Classrel _) =
+ []
+ | eqns_of_stmt (_, CodeThingol.Classparam _) =
+ []
+ | eqns_of_stmt (inst, CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops))) =
+ [(inst, (arities, [([], IConst (class, ([], [])) `$$
+ map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts
+ @ map (IConst o snd o fst) instops)]))];
-fun eqns_of_stmt ((_, CodeThingol.Fun (_, [])), _) =
- NONE
- | eqns_of_stmt ((name, CodeThingol.Fun (_, eqns)), deps) =
- SOME ((name, map fst eqns), deps)
- | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) =
- NONE
- | eqns_of_stmt ((_, CodeThingol.Datatype _), _) =
- NONE
- | eqns_of_stmt ((_, CodeThingol.Class _), _) =
- NONE
- | eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
- NONE
- | eqns_of_stmt ((_, CodeThingol.Classparam _), _) =
- NONE
- | eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
- NONE;
+fun compile_stmts stmts_deps =
+ let
+ val names = map (fst o fst) stmts_deps;
+ val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
+ val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
+ val compiled_deps = names_deps |> maps snd |> distinct (op =) |> subtract (op =) names;
+ fun compile gr = (eqnss, compiled_deps) |> assemble_eqnss gr |> compile_univs gr |> rpair gr;
+ in
+ fold (fn name => Graph.new_node (name, NONE)) names
+ #> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
+ #> compile
+ #-> fold (fn (name, univ) => Graph.map_node name (K (SOME univ)))
+ end;
-fun compile_stmts thy is_fun =
- map_filter eqns_of_stmt
- #> split_list
- #> assemble_eqnss thy is_fun
- #> compile_univs (Nbe_Functions.get thy);
+fun ensure_stmts code =
+ let
+ fun add_stmts names gr = if exists (compiled gr) names then gr else gr
+ |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
+ Graph.imm_succs code name)) names);
+ in fold_rev add_stmts (Graph.strong_conn code) end;
-fun eval_term thy is_fun =
- assemble_eval thy is_fun
- #> compile_univs (Nbe_Functions.get thy)
+fun assemble_eval gr (((vs, ty), t), deps) =
+ let
+ val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
+ val bind_deps = map nbe_lookup (filter (is_some o Graph.get_node gr) deps);
+ val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
+ |> rev;
+ val bind_value = ml_fundefs [(nbe_value,
+ [([ml_list (map nbe_bound frees @ dict_params)],
+ assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])];
+ val result = ml_list [nbe_value `$` ml_list
+ (map nbe_free frees @ map nbe_dfree dict_params)]
+ |> ml_delay;
+ in ([nbe_value], ml_Let (bind_deps @ [bind_value]) result) end;
+
+fun eval_term gr =
+ assemble_eval gr
+ #> compile_univs gr
#> the_single
#> snd;
-(** compilation and evaluation **)
-
-(* ensure global functions *)
-
-fun ensure_funs thy code =
- let
- fun add_dep (name, dep) gr =
- if can (Graph.get_node gr) name andalso can (Graph.get_node gr) dep
- then Graph.add_edge (name, dep) gr else gr;
- fun compile' stmts gr =
- let
- val compiled = compile_stmts thy (defined gr) stmts;
- val names = map (fst o fst) stmts;
- val deps = maps snd stmts;
- in
- Nbe_Functions.change thy (fold Graph.new_node compiled
- #> fold (fn name => fold (curry add_dep name) deps) names)
- end;
- val nbe_gr = Nbe_Functions.get thy;
- val stmtss = rev (Graph.strong_conn code)
- |> (map o map_filter) (fn name => if defined nbe_gr name
- then NONE
- else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
- |> filter_out null
- in fold compile' stmtss nbe_gr end;
+(** evaluation **)
(* reification *)
fun term_of_univ thy t =
let
+ fun take_until f [] = []
+ | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
+ fun is_dict (Const (c, _)) =
+ (is_some o CodeName.class_rev thy) c
+ orelse (is_some o CodeName.classrel_rev thy) c
+ orelse (is_some o CodeName.instance_rev thy) c
+ | is_dict (DFree _) = true
+ | is_dict _ = false;
fun of_apps bounds (t, ts) =
fold_map (of_univ bounds) ts
#>> (fn ts' => list_comb (t, rev ts'))
and of_univ bounds (Const (name, ts)) typidx =
let
+ val ts' = take_until is_dict ts;
val SOME c = CodeName.const_rev thy name;
val T = Code.default_typ thy c;
val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
val typidx' = typidx + maxidx_of_typ T' + 1;
- in of_apps bounds (Term.Const (c, T'), ts) typidx' end
+ in of_apps bounds (Term.Const (c, T'), ts') typidx' end
| of_univ bounds (Free (name, ts)) typidx =
of_apps bounds (Term.Free (name, dummyT), ts) typidx
| of_univ bounds (BVar (name, ts)) typidx =
@@ -331,6 +336,33 @@
|-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
in of_univ 0 t 0 |> fst end;
+(* function store *)
+
+structure Nbe_Functions = CodeDataFun
+(
+ type T = univ_gr;
+ val empty = Graph.empty;
+ fun merge _ = Graph.merge (K true);
+ fun purge _ NONE _ = Graph.empty
+ | purge NONE _ _ = Graph.empty
+ | purge (SOME thy) (SOME cs) gr =
+ let
+ val cs_exisiting =
+ map_filter (CodeName.const_rev thy) (Graph.keys gr);
+ val dels = (Graph.all_preds gr
+ o map (CodeName.const thy)
+ o filter (member (op =) cs_exisiting)
+ ) cs;
+ in Graph.del_nodes dels gr end;
+);
+
+(* compilation, evaluation and reification *)
+
+fun compile_eval thy code vs_ty_t deps =
+ (vs_ty_t, deps)
+ |> eval_term (Nbe_Functions.change thy (ensure_stmts code))
+ |> term_of_univ thy;
+
(* evaluation with type reconstruction *)
fun eval thy code t vs_ty_t deps =
@@ -349,16 +381,14 @@
error ("Illegal schematic type variables in normalized term: "
^ setmp show_types true (Sign.string_of_term thy) t);
in
- (vs_ty_t, deps)
- |> eval_term thy (defined (ensure_funs thy code))
- |> term_of_univ thy
+ compile_eval thy code vs_ty_t deps
|> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
|> anno_vars
|> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
|> constrain
|> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
+ |> tracing (fn t => "---\n")
|> check_tvars
- |> tracing (fn _ => "---\n")
end;
(* evaluation oracle *)