--- a/src/Tools/nbe.ML Fri Mar 28 14:13:36 2025 +0100
+++ b/src/Tools/nbe.ML Fri Mar 28 14:13:37 2025 +0100
@@ -255,9 +255,6 @@
val univs_cookie = (get_result, put_result, name_put);
-fun nbe_fun idx_of 0 (Code_Symbol.Constant "") = "nbe_value"
- | nbe_fun idx_of i sym = "c_" ^ string_of_int (idx_of sym)
- ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
fun nbe_bound v = "v_" ^ v;
fun nbe_bound_optional NONE = "_"
@@ -267,13 +264,8 @@
(*note: these three are the "turning spots" where proper argument order is established!*)
fun nbe_apps t [] = t
| nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)];
-fun nbe_apps_local idx_of i c ts = nbe_fun idx_of i c `$` ml_list (rev ts);
-fun nbe_apps_constr ctxt idx_of c ts =
- let
- val c' = if Config.get ctxt trace
- then string_of_int (idx_of c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)"
- else string_of_int (idx_of c);
- in name_const `$` ("(" ^ c' ^ ", " ^ ml_list (rev ts) ^ ")") end;
+fun nbe_apps_local c ts = c `$` ml_list (rev ts);
+fun nbe_apps_constr c ts = name_const `$` ("(" ^ c ^ ", " ^ ml_list (rev ts) ^ ")");
fun nbe_abss 0 f = f `$` ml_list []
| nbe_abss n f = name_abss `$$` [string_of_int n, f];
@@ -288,26 +280,37 @@
(* code generation *)
-fun assemble_eqnss ctxt idx_of deps eqnss =
+fun assemble_eqnss ctxt idx_of_const deps eqnss =
let
- fun prep_eqns (c, (vs, eqns)) =
+ fun prep_eqns (sym, (vs, eqns)) =
let
val dicts = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
val num_args = length dicts + ((length o fst o hd) eqns);
- in (c, (num_args, (dicts, eqns))) end;
+ in (sym, (num_args, (dicts, eqns))) end;
val eqnss' = map prep_eqns eqnss;
+ fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value"
+ | fun_ident i sym = "c_" ^ string_of_int (idx_of_const sym)
+ ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
+ fun constr_fun_ident c =
+ if Config.get ctxt trace
+ then string_of_int (idx_of_const c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)"
+ else string_of_int (idx_of_const c);
+
+ fun apply_local i sym = nbe_apps_local (fun_ident i sym);
+ fun apply_constr sym = nbe_apps_constr (constr_fun_ident sym);
+
fun assemble_constapp sym dss ts =
let
val ts' = (maps o map) assemble_dict dss @ ts;
in case AList.lookup (op =) eqnss' sym
of SOME (num_args, _) => if num_args <= length ts'
then let val (ts1, ts2) = chop num_args ts'
- in nbe_apps (nbe_apps_local idx_of 0 sym ts1) ts2
- end else nbe_apps (nbe_abss num_args (nbe_fun idx_of 0 sym)) ts'
+ in nbe_apps (apply_local 0 sym ts1) ts2
+ end else nbe_apps (nbe_abss num_args (fun_ident 0 sym)) ts'
| NONE => if member (op =) deps sym
- then nbe_apps (nbe_fun idx_of 0 sym) ts'
- else nbe_apps_constr ctxt idx_of sym ts'
+ then nbe_apps (fun_ident 0 sym) ts'
+ else apply_constr sym ts'
end
and assemble_classrels classrels =
fold_rev (fn classrel => assemble_constapp (Class_Relation classrel) [] o single) classrels
@@ -359,9 +362,9 @@
fun assemble_eqn sym dicts default_args (i, (args, rhs)) =
let
val match_cont = if Code_Symbol.is_value sym then NONE
- else SOME (nbe_apps_local idx_of (i + 1) sym (dicts @ default_args));
+ else SOME (apply_local (i + 1) sym (dicts @ default_args));
val assemble_arg = assemble_iterm
- (fn sym' => fn dss => fn ts => nbe_apps_constr ctxt idx_of sym' ((maps o map) (K "_")
+ (fn sym' => fn dss => fn ts => apply_constr sym' ((maps o map) (K "_")
dss @ ts)) NONE;
val assemble_rhs = assemble_iterm assemble_constapp match_cont;
val (samepairs, args') = subst_nonlin_vars args;
@@ -374,19 +377,19 @@
| SOME default_rhs =>
[([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
([ml_list (rev (dicts @ default_args))], default_rhs)]
- in (nbe_fun idx_of i sym, eqns) end;
+ in (fun_ident i sym, eqns) end;
fun assemble_eqns (sym, (num_args, (dicts, eqns))) =
let
val default_args = map nbe_default (Name.invent_global "a" (num_args - length dicts));
val eqns' = map_index (assemble_eqn sym dicts default_args) eqns
- @ (if Code_Symbol.is_value sym then [] else [(nbe_fun idx_of (length eqns) sym,
+ @ (if Code_Symbol.is_value sym then [] else [(fun_ident (length eqns) sym,
[([ml_list (rev (dicts @ default_args))],
- nbe_apps_constr ctxt idx_of sym (dicts @ default_args))])]);
- in (eqns', nbe_abss num_args (nbe_fun idx_of 0 sym)) end;
+ apply_constr sym (dicts @ default_args))])]);
+ in (eqns', nbe_abss num_args (fun_ident 0 sym)) end;
val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
- val deps_vars = ml_list (map (nbe_fun idx_of 0) deps);
+ val deps_vars = ml_list (map (fun_ident 0) deps);
in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
@@ -397,19 +400,19 @@
let
val (deps, deps_vals) = split_list (map_filter
(fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps);
- val idx_of = raw_deps
+ val idx_of_const = raw_deps
|> map (fn dep => (dep, snd (Code_Symbol.Graph.get_node nbe_program dep)))
|> AList.lookup (op =)
|> (fn f => the o f);
- val s = assemble_eqnss ctxt idx_of deps eqnss;
- val cs = map fst eqnss;
+ val s = assemble_eqnss ctxt idx_of_const deps eqnss;
+ val syms = map fst eqnss;
in
s
|> traced ctxt (fn s => "\n--- code to be evaluated:\n" ^ s)
|> pair ""
|> Code_Runtime.value ctxt univs_cookie
|> (fn f => f deps_vals)
- |> (fn univs => cs ~~ univs)
+ |> (fn univs => syms ~~ univs)
end;
@@ -450,11 +453,11 @@
(* compilation of whole programs *)
-fun ensure_const_idx name (nbe_program, (maxidx, idx_tab)) =
+fun ensure_const_idx name (nbe_program, (maxidx, const_tab)) =
if can (Code_Symbol.Graph.get_node nbe_program) name
- then (nbe_program, (maxidx, idx_tab))
+ then (nbe_program, (maxidx, const_tab))
else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program,
- (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
+ (maxidx + 1, Inttab.update_new (maxidx, name) const_tab));
fun compile_stmts ctxt stmts_deps =
let
@@ -472,16 +475,16 @@
fold ensure_const_idx refl_deps
#> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps
#> compile
- #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ))))
+ #-> fold (fn (sym, univ) => (Code_Symbol.Graph.map_node sym o apfst) (K (SOME univ))))
end;
fun compile_program { ctxt, program } =
let
- fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
- then (nbe_program, (maxidx, idx_tab))
- else (nbe_program, (maxidx, idx_tab))
- |> compile_stmts ctxt (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
- Code_Symbol.Graph.immediate_succs program name)) names);
+ fun add_stmts names (nbe_program, (maxidx, const_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
+ then (nbe_program, (maxidx, const_tab))
+ else (nbe_program, (maxidx, const_tab))
+ |> compile_stmts ctxt (map (fn sym => ((sym, Code_Symbol.Graph.get_node program sym),
+ Code_Symbol.Graph.immediate_succs program sym)) names);
in
fold_rev add_stmts (Code_Symbol.Graph.strong_conn program)
end;
@@ -501,18 +504,18 @@
|> (fn t => apps t (rev dict_frees))
end;
-fun reconstruct_term ctxt (idx_tab : Code_Symbol.T Inttab.table) t =
+fun reconstruct_term ctxt (const_tab : Code_Symbol.T Inttab.table) t =
let
fun take_until f [] = []
| take_until f (x :: xs) = if f x then [] else x :: take_until f xs;
fun is_dict (Const (idx, _)) =
- (case Inttab.lookup idx_tab idx of
+ (case Inttab.lookup const_tab idx of
SOME (Constant _) => false
| _ => true)
| is_dict (DFree _) = true
| is_dict _ = false;
fun const_of_idx idx =
- case Inttab.lookup idx_tab idx of SOME (Constant const) => const;
+ case Inttab.lookup const_tab idx of SOME (Constant const) => const;
fun of_apps bounds (t, ts) =
fold_map (of_univ bounds) ts
#>> (fn ts' => list_comb (t, rev ts'))
@@ -533,12 +536,12 @@
|-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
in of_univ 0 t 0 |> fst end;
-fun compile_and_reconstruct_term { ctxt, nbe_program, idx_tab, deps, term } =
+fun compile_and_reconstruct_term { ctxt, nbe_program, const_tab, deps, term } =
compile_term
{ ctxt = ctxt, nbe_program = nbe_program, deps = deps, term = term }
- |> reconstruct_term ctxt idx_tab;
+ |> reconstruct_term ctxt const_tab;
-fun normalize_term (nbe_program, idx_tab) raw_ctxt t_original ((vs, ty) : typscheme, t) deps =
+fun normalize_term (nbe_program, const_tab) raw_ctxt t_original ((vs, ty) : typscheme, t) deps =
let
val ctxt = Syntax.init_pretty_global (Proof_Context.theory_of raw_ctxt);
val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
@@ -553,7 +556,7 @@
else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t');
in
Code_Preproc.timed "computing NBE expression" #ctxt compile_and_reconstruct_term
- { ctxt = ctxt, nbe_program = nbe_program, idx_tab = idx_tab, deps = deps, term = (vs, t) }
+ { ctxt = ctxt, nbe_program = nbe_program, const_tab = const_tab, deps = deps, term = (vs, t) }
|> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t)
|> type_infer
|> traced ctxt (fn t => "Types inferred:\n" ^ string_of_term t)
@@ -572,11 +575,11 @@
fun compile ignore_cache ctxt program =
let
- val (nbe_program, (_, idx_tab)) =
+ val (nbe_program, (_, const_tab)) =
Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt))
(Code_Preproc.timed "compiling NBE program" #ctxt
compile_program { ctxt = ctxt, program = program });
- in (nbe_program, idx_tab) end;
+ in (nbe_program, const_tab) end;
(* evaluation oracle *)
@@ -590,11 +593,11 @@
val (_, raw_oracle) =
Theory.setup_result (Thm.add_oracle (\<^binding>\<open>normalization_by_evaluation\<close>,
- fn (nbe_program_idx_tab, ctxt, vs_ty_t, deps, ct) =>
- mk_equals ctxt ct (normalize_term nbe_program_idx_tab ctxt (Thm.term_of ct) vs_ty_t deps)));
+ fn (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct) =>
+ mk_equals ctxt ct (normalize_term nbe_program_const_tab ctxt (Thm.term_of ct) vs_ty_t deps)));
-fun oracle nbe_program_idx_tab ctxt vs_ty_t deps ct =
- raw_oracle (nbe_program_idx_tab, ctxt, vs_ty_t, deps, ct);
+fun oracle nbe_program_const_tab ctxt vs_ty_t deps ct =
+ raw_oracle (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct);
fun dynamic_conv ctxt = lift_triv_classes_conv ctxt
(fn ctxt' => Code_Thingol.dynamic_conv ctxt' (fn program =>