--- a/src/Tools/nbe.ML Sun Apr 06 18:12:52 2025 +0200
+++ b/src/Tools/nbe.ML Sun Apr 06 18:12:53 2025 +0200
@@ -184,28 +184,27 @@
| Abs of (int * (Univ list -> Univ)) * Univ list
(*abstractions as closures*)
-
(* constructor functions *)
fun abss n f = Abs ((n, f), []);
-fun apps (Abs ((n, f), xs)) ys = let val k = n - length ys in
+fun apps (Abs ((n, f), us)) ws = let val k = n - length ws in
case int_ord (k, 0)
- of EQUAL => f (ys @ xs)
- | LESS => let val (zs, ws) = chop (~ k) ys in apps (f (ws @ xs)) zs end
- | GREATER => Abs ((k, f), ys @ xs) (*note: reverse convention also for apps!*)
+ of EQUAL => f (ws @ us)
+ | LESS => let val (ws1, ws2) = chop (~ k) ws in apps (f (ws2 @ us)) ws1 end
+ | GREATER => Abs ((k, f), ws @ us) (*note: reverse convention also for apps!*)
end
| apps (Const (name, xs)) ys = Const (name, ys @ xs)
| apps (BVar (n, xs)) ys = BVar (n, ys @ xs);
-fun same_type (Type (tyco1, types1), Type (tyco2, types2)) =
- (tyco1 = tyco2) andalso eq_list same_type (types1, types2)
+fun same_type (Type (tyco1, ty1), Type (tyco2, tys2)) =
+ (tyco1 = tyco2) andalso eq_list same_type (ty1, tys2)
| same_type (TParam v1, TParam v2) = (v1 = v2)
| same_type _ = false;
-fun same (Const ((k1, ts1), xs1), Const ((k2, ts2), xs2)) =
- (k1 = k2) andalso eq_list same_type (ts1, ts2) andalso eq_list same (xs1, xs2)
+fun same (Const ((k1, typargs1), us1), Const ((k2, typargs2), us2)) =
+ (k1 = k2) andalso eq_list same_type (typargs1, typargs2) andalso eq_list same (us1, us2)
| same (DFree (n1, i1), DFree (n2, i2)) = (n1 = n2) andalso (i1 = i2)
- | same (BVar (i1, xs1), BVar (i2, xs2)) = (i1 = i2) andalso eq_list same (xs1, xs2)
+ | same (BVar (i1, us1), BVar (i2, us2)) = (i1 = i2) andalso eq_list same (us1, us2)
| same _ = false;
@@ -271,21 +270,27 @@
val univs_cookie = (get_result, put_result, name_put);
+(*
+ Convention: parameters representing ("assembled") string representations of logical entities
+ are prefixed with an "a_" -- unless they are an unqualified name ready to become part of
+ an ML identifier.
+*)
+
fun nbe_tparam v = "t_" ^ v;
fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
fun nbe_bound v = "v_" ^ v;
fun nbe_bound_optional NONE = "_"
| nbe_bound_optional (SOME v) = nbe_bound v;
fun nbe_default v = "w_" ^ v;
-fun nbe_type n ts = name_type `$` (quote n `*` ml_list ts);
-fun nbe_fun c tys = c `$` ml_list tys;
+fun nbe_type a_sym a_tys = name_type `$` (quote a_sym `*` ml_list a_tys);
+fun nbe_fun a_sym a_typargs = a_sym `$` ml_list a_typargs;
(*note: these three are the "turning spots" where proper argument order is established!*)
-fun nbe_apps t [] = t
- | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)];
-fun nbe_apps_fun c tys ts = nbe_fun c tys `$` ml_list (rev ts);
-fun nbe_apps_constr c tys ts = name_const `$` ((c `*` ml_list tys) `*` ml_list (rev ts));
-fun nbe_apps_constmatch c ts = name_const `$` ((c `*` "_") `*` ml_list (rev ts));
+fun nbe_apps a_u [] = a_u
+ | nbe_apps a_u a_us = name_apps `$$` [a_u, ml_list (rev a_us)];
+fun nbe_apps_fun a_sym a_typargs a_us = nbe_fun a_sym a_typargs `$` ml_list (rev a_us);
+fun nbe_apps_constr a_sym a_typargs a_us = name_const `$` ((a_sym `*` ml_list a_typargs) `*` ml_list (rev a_us));
+fun nbe_apps_constmatch a_sym a_us = name_const `$` ((a_sym `*` "_") `*` ml_list (rev a_us));
fun nbe_abss 0 f = f `$` ml_list []
| nbe_abss n f = name_abss `$$` [string_of_int n, f];
@@ -324,54 +329,54 @@
fun preprocess_eqns (sym, (vs, eqns)) =
let
- val s_tparams = map (fn (v, _) => nbe_tparam v) vs;
- val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
- val num_args = length dict_params + ((length o fst o hd) eqns);
- val default_params = map nbe_default (Name.invent_global "a" (num_args - length dict_params));
- in (sym, (num_args, (s_tparams, dict_params, (map o apfst) subst_nonlin_vars eqns, default_params))) end;
+ val a_tparams = map (fn (v, _) => nbe_tparam v) vs;
+ val a_dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
+ val num_args = length a_dict_params + ((length o fst o hd) eqns);
+ val a_default_params = map nbe_default (Name.invent_global "a" (num_args - length a_dict_params));
+ in (sym, (num_args, (a_tparams, a_dict_params, (map o apfst) subst_nonlin_vars eqns, a_default_params))) end;
fun assemble_type (tyco `%% tys) = nbe_type tyco (map assemble_type tys)
| assemble_type (ITyVar v) = nbe_tparam v
-fun assemble_preprocessed_eqnss ctxt idx_of_const deps eqnss =
+fun assemble_preprocessed_eqnss ctxt idx_of_sym deps eqnss =
let
fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value"
- | fun_ident i sym = "c_" ^ string_of_int (idx_of_const sym)
+ | fun_ident i sym = "c_" ^ string_of_int (idx_of_sym sym)
^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
- fun constr_ident c =
+ fun constr_ident sym =
if Config.get ctxt trace
- then string_of_int (idx_of_const c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)"
- else string_of_int (idx_of_const c);
+ then string_of_int (idx_of_sym sym) ^ " (*" ^ Code_Symbol.default_base sym ^ "*)"
+ else string_of_int (idx_of_sym sym);
fun assemble_fun i sym = nbe_fun (fun_ident i sym);
fun assemble_app_fun i sym = nbe_apps_fun (fun_ident i sym);
fun assemble_app_constr sym = nbe_apps_constr (constr_ident sym);
fun assemble_app_constmatch sym = nbe_apps_constmatch (constr_ident sym);
- fun assemble_constapp sym tys dictss ts =
+ fun assemble_constapp sym typargs dictss a_ts =
let
- val s_tys = map (assemble_type) tys;
- val ts' = (maps o map) assemble_dict (map2 (fn ty => map (fn dict => (ty, dict))) tys dictss) @ ts;
+ val a_typargs = map (assemble_type) typargs;
+ val a_ts' = (maps o map) assemble_dict (map2 (fn ty => map (fn dict => (ty, dict))) typargs dictss) @ a_ts;
in case AList.lookup (op =) eqnss sym
- of SOME (num_args, _) => if num_args <= length ts'
- then let val (ts1, ts2) = chop num_args ts'
- in nbe_apps (assemble_app_fun 0 sym s_tys ts1) ts2
- end else nbe_apps (nbe_abss num_args (assemble_fun 0 sym s_tys)) ts'
+ of SOME (num_args, _) => if num_args <= length a_ts'
+ then let val (a_ts1, a_ts2) = chop num_args a_ts'
+ in nbe_apps (assemble_app_fun 0 sym a_typargs a_ts1) a_ts2
+ end else nbe_apps (nbe_abss num_args (assemble_fun 0 sym a_typargs)) a_ts'
| NONE => if member (op =) deps sym
- then nbe_apps (assemble_fun 0 sym s_tys) ts'
- else assemble_app_constr sym s_tys ts'
+ then nbe_apps (assemble_fun 0 sym a_typargs) a_ts'
+ else assemble_app_constr sym a_typargs a_ts'
end
and assemble_classrels classrels =
fold_rev (fn classrel => assemble_constapp (Class_Relation classrel) [] [] o single) classrels
- and assemble_dict (ty, Dict (classrels, x)) =
- assemble_classrels classrels (assemble_plain_dict ty x)
+ and assemble_dict (ty, Dict (classrels, dict)) =
+ assemble_classrels classrels (assemble_plain_dict ty dict)
and assemble_plain_dict (_ `%% tys) (Dict_Const (inst, dictss)) =
assemble_constapp (Class_Instance inst) tys (map snd dictss) []
| assemble_plain_dict _ (Dict_Var { var, index, ... }) =
nbe_dict var index
- fun assemble_constmatch sym _ dictss ts =
- assemble_app_constmatch sym ((maps o map) (K "_") dictss @ ts);
+ fun assemble_constmatch sym _ dictss a_ts =
+ assemble_app_constmatch sym ((maps o map) (K "_") dictss @ a_ts);
fun assemble_iterm constapp =
let
@@ -379,50 +384,51 @@
let
val (t', ts) = Code_Thingol.unfold_app t
in of_iapp match_continuation t' (fold_rev (cons o of_iterm NONE) ts []) end
- and of_iapp match_continuation (IConst { sym, typargs = tys, dictss, ... }) ts = constapp sym tys dictss ts
- | of_iapp match_continuation (IVar v) ts = nbe_apps (nbe_bound_optional v) ts
- | of_iapp match_continuation ((v, _) `|=> (t, _)) ts =
- nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t))) ts
- | of_iapp match_continuation (ICase { term = t, clauses = clauses, primitive = t0, ... }) ts =
+ and of_iapp match_continuation (IConst { sym, typargs, dictss, ... }) = constapp sym typargs dictss
+ | of_iapp match_continuation (IVar v) = nbe_apps (nbe_bound_optional v)
+ | of_iapp match_continuation ((v, _) `|=> (t, _)) =
+ nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t)))
+ | of_iapp match_continuation (ICase { term = t, clauses = clauses, primitive = t0, ... }) =
nbe_apps (ml_cases (of_iterm NONE t)
(map (fn (p, t) => (assemble_iterm assemble_constmatch NONE p, of_iterm match_continuation t)) clauses
- @ [("_", case match_continuation of SOME s => s | NONE => of_iterm NONE t0)])) ts
+ @ [("_", case match_continuation of SOME s => s | NONE => of_iterm NONE t0)]))
in of_iterm end;
val assemble_args = map (assemble_iterm assemble_constmatch NONE);
val assemble_rhs = assemble_iterm assemble_constapp;
- fun assemble_eqn sym s_tparams dict_params default_params (i, ((samepairs, args), rhs)) =
+ fun assemble_eqn sym a_tparams a_dict_params a_default_params (i, ((samepairs, args), rhs)) =
let
- val default_rhs = assemble_app_fun (i + 1) sym s_tparams (dict_params @ default_params);
+ val default_rhs = assemble_app_fun (i + 1) sym a_tparams (a_dict_params @ a_default_params);
val s_args = assemble_args args;
val s_rhs = if null samepairs then assemble_rhs (SOME default_rhs) rhs
else ml_if (ml_and (map nbe_same samepairs))
(assemble_rhs (SOME default_rhs) rhs) default_rhs;
- val eqns = [([ml_list s_tparams, ml_list (rev (dict_params @ map2 ml_as default_params s_args))], s_rhs),
- ([ml_list s_tparams, ml_list (rev (dict_params @ default_params))], default_rhs)]
+ val eqns = [([ml_list a_tparams, ml_list (rev (a_dict_params @ map2 ml_as a_default_params s_args))], s_rhs),
+ ([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))], default_rhs)]
in (fun_ident i sym, eqns) end;
- fun assemble_default_eqn sym s_tparams dict_params default_params i =
+ fun assemble_default_eqn sym a_tparams a_dict_params a_default_params i =
(fun_ident i sym,
- [([ml_list s_tparams, ml_list (rev (dict_params @ default_params))], assemble_app_constr sym s_tparams (dict_params @ default_params))])
+ [([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))],
+ assemble_app_constr sym a_tparams (a_dict_params @ a_default_params))])
- fun assemble_value_eqn sym s_tparams dict_params (([], args), rhs) =
+ fun assemble_value_eqn sym a_tparams a_dict_params (([], args), rhs) =
(fun_ident 0 sym,
- [([ml_list s_tparams, ml_list (rev (dict_params @ assemble_args args))], assemble_rhs NONE rhs)]);
+ [([ml_list a_tparams, ml_list (rev (a_dict_params @ assemble_args args))], assemble_rhs NONE rhs)]);
- fun assemble_eqns (sym, (num_args, (s_tparams, dict_params, eqns, default_params))) =
- (if Code_Symbol.is_value sym then [assemble_value_eqn sym s_tparams dict_params (the_single eqns)]
- else map_index (assemble_eqn sym s_tparams dict_params default_params) eqns
- @ [assemble_default_eqn sym s_tparams dict_params default_params (length eqns)],
- ml_abs (ml_list s_tparams) (nbe_abss num_args (assemble_fun 0 sym s_tparams)));
+ fun assemble_eqns (sym, (num_args, (a_tparams, a_dict_params, eqns, default_params))) =
+ (if Code_Symbol.is_value sym then [assemble_value_eqn sym a_tparams a_dict_params (the_single eqns)]
+ else map_index (assemble_eqn sym a_tparams a_dict_params default_params) eqns
+ @ [assemble_default_eqn sym a_tparams a_dict_params default_params (length eqns)],
+ ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun 0 sym a_tparams)));
val (fun_vars, fun_vals) = map_split assemble_eqns eqnss;
val deps_vars = ml_list (map (fun_ident 0) deps);
in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
-fun assemble_eqnss ctxt idx_of_const deps eqnss =
- assemble_preprocessed_eqnss ctxt idx_of_const deps (map preprocess_eqns eqnss);
+fun assemble_eqnss ctxt idx_of_sym deps eqnss =
+ assemble_preprocessed_eqnss ctxt idx_of_sym deps (map preprocess_eqns eqnss);
(* compilation of equations *)
@@ -432,26 +438,26 @@
let
val (deps, deps_vals) = split_list (map_filter
(fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps);
- val idx_of_const = raw_deps
+ val idx_of_sym = raw_deps
|> map (fn dep => (dep, snd (Code_Symbol.Graph.get_node nbe_program dep)))
|> AList.lookup (op =)
|> (fn f => the o f);
- val s = assemble_eqnss ctxt idx_of_const deps eqnss;
+ val s = assemble_eqnss ctxt idx_of_sym deps eqnss;
val syms = map fst eqnss;
in
s
|> traced ctxt (fn s => "\n--- code to be evaluated:\n" ^ s)
|> pair ""
|> Code_Runtime.value ctxt univs_cookie
- |> (fn f => f deps_vals)
- |> (fn poly_univs => syms ~~ poly_univs)
+ |> (fn dependent_fs => dependent_fs deps_vals)
+ |> (fn fs => syms ~~ fs)
end;
(* extraction of equations from statements *)
-fun dummy_const sym tys dictss =
- IConst { sym = sym, typargs = tys, dictss = dictss,
+fun dummy_const sym typargs dictss =
+ IConst { sym = sym, typargs = typargs, dictss = dictss,
dom = [], annotation = NONE, range = ITyVar "" };
fun eqns_of_stmt (_, Code_Thingol.NoStmt) =
@@ -486,11 +492,11 @@
(* compilation of whole programs *)
-fun ensure_const_idx name (nbe_program, (maxidx, const_tab)) =
- if can (Code_Symbol.Graph.get_node nbe_program) name
- then (nbe_program, (maxidx, const_tab))
- else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program,
- (maxidx + 1, Inttab.update_new (maxidx, name) const_tab));
+fun ensure_sym_idx sym (nbe_program, (max_idx, sym_tab)) =
+ if can (Code_Symbol.Graph.get_node nbe_program) sym
+ then (nbe_program, (max_idx, sym_tab))
+ else (Code_Symbol.Graph.new_node (sym, (NONE, max_idx)) nbe_program,
+ (max_idx + 1, Inttab.update_new (max_idx, sym) sym_tab));
fun compile_stmts ctxt stmts_deps =
let
@@ -505,7 +511,7 @@
|> compile_eqnss ctxt nbe_program refl_deps
|> rpair nbe_program;
in
- fold ensure_const_idx refl_deps
+ fold ensure_sym_idx refl_deps
#> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps
#> compile
#-> fold (fn (sym, univ) => (Code_Symbol.Graph.map_node sym o apfst) (K (SOME univ))))
@@ -513,10 +519,10 @@
fun compile_program { ctxt, program } =
let
- fun add_stmts names (nbe_program, (maxidx, const_tab)) =
+ fun add_stmts names (nbe_program, (max_idx, sym_tab)) =
if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
- then (nbe_program, (maxidx, const_tab))
- else (nbe_program, (maxidx, const_tab))
+ then (nbe_program, (max_idx, sym_tab))
+ else (nbe_program, (max_idx, sym_tab))
|> compile_stmts ctxt (map (fn sym => ((sym, Code_Symbol.Graph.get_node program sym),
Code_Symbol.Graph.immediate_succs program sym)) names);
in
@@ -542,37 +548,37 @@
|> (fn f => apps (f tparams) (rev dict_frees))
end;
-fun reconstruct_term ctxt const_tab tfrees t =
+fun reconstruct_term ctxt sym_tab tfrees =
let
fun take_until f [] = []
| take_until f (x :: xs) = if f x then [] else x :: take_until f xs;
fun is_dict (Const ((idx, _), _)) =
- (case Inttab.lookup const_tab idx of
+ (case Inttab.lookup sym_tab idx of
SOME (Constant _) => false
| _ => true)
| is_dict (DFree _) = true
| is_dict _ = false;
fun const_of_idx idx =
- case Inttab.lookup const_tab idx of SOME (Constant const) => const;
+ case Inttab.lookup sym_tab idx of SOME (Constant const) => const;
fun reconstruct_type (Type (tyco, tys)) = Term.Type (tyco, map reconstruct_type tys)
| reconstruct_type (TParam v) = TFree (v, the (AList.lookup (op =) tfrees v));
fun of_apps bounds (t, ts) =
list_comb (t, rev (map (of_univ bounds) ts))
- and of_univ bounds (Const ((idx, tys), ts)) =
+ and of_univ bounds (Const ((idx, tparams), us)) =
let
val const = const_of_idx idx;
- val ts' = take_until is_dict ts;
- val T = Consts.instance (Proof_Context.consts_of ctxt) (const, map reconstruct_type tys);
- in of_apps bounds (Term.Const (const, T), ts') end
- | of_univ bounds (BVar (n, ts)) =
- of_apps bounds (Bound (bounds - n - 1), ts)
- | of_univ bounds (t as Abs _) =
- Term.Abs ("u", dummyT, of_univ (bounds + 1) (apps t [BVar (bounds, [])]))
- in of_univ 0 t end;
+ val us' = take_until is_dict us;
+ val T = Consts.instance (Proof_Context.consts_of ctxt) (const, map reconstruct_type tparams);
+ in of_apps bounds (Term.Const (const, T), us') end
+ | of_univ bounds (BVar (i, us)) =
+ of_apps bounds (Bound (bounds - i - 1), us)
+ | of_univ bounds (u as Abs _) =
+ Term.Abs ("u", dummyT, of_univ (bounds + 1) (apps u [BVar (bounds, [])]))
+ in of_univ 0 end;
-fun compile_and_reconstruct_term { ctxt, nbe_program, const_tab, deps, tfrees, vs_ty_t } =
+fun compile_and_reconstruct_term { ctxt, nbe_program, sym_tab, deps, tfrees, vs_ty_t } =
compile_term { ctxt = ctxt, nbe_program = nbe_program, deps = deps, tfrees = tfrees, vs_ty_t = vs_ty_t }
- |> reconstruct_term ctxt const_tab tfrees;
+ |> reconstruct_term ctxt sym_tab tfrees;
fun retype_term ctxt t T =
let
@@ -585,7 +591,7 @@
singleton (Variable.export_terms ctxt' ctxt') (Syntax.check_term ctxt' (Type.constraint T t))
end;
-fun normalize_term (nbe_program, const_tab) raw_ctxt t_original vs_ty_t deps =
+fun normalize_term (nbe_program, sym_tab) raw_ctxt t_original vs_ty_t deps =
let
val T = fastype_of t_original;
val tfrees = Term.add_tfrees t_original [];
@@ -601,7 +607,7 @@
else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t');
in
Code_Preproc.timed "computing NBE expression" #ctxt compile_and_reconstruct_term
- { ctxt = ctxt, nbe_program = nbe_program, const_tab = const_tab, deps = deps,
+ { ctxt = ctxt, nbe_program = nbe_program, sym_tab = sym_tab, deps = deps,
tfrees = tfrees, vs_ty_t = vs_ty_t }
|> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t)
|> retype
@@ -621,11 +627,11 @@
fun compile ignore_cache ctxt program =
let
- val (nbe_program, (_, const_tab)) =
+ val (nbe_program, (_, sym_tab)) =
Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt))
(Code_Preproc.timed "compiling NBE program" #ctxt
compile_program { ctxt = ctxt, program = program });
- in (nbe_program, const_tab) end;
+ in (nbe_program, sym_tab) end;
(* evaluation oracle *)
@@ -639,11 +645,11 @@
val (_, raw_oracle) =
Theory.setup_result (Thm.add_oracle (\<^binding>\<open>normalization_by_evaluation\<close>,
- fn (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct) =>
- mk_equals ctxt ct (normalize_term nbe_program_const_tab ctxt (Thm.term_of ct) vs_ty_t deps)));
+ fn (nbe_program_sym_tab, ctxt, vs_ty_t, deps, ct) =>
+ mk_equals ctxt ct (normalize_term nbe_program_sym_tab ctxt (Thm.term_of ct) vs_ty_t deps)));
-fun oracle nbe_program_const_tab ctxt vs_ty_t deps ct =
- raw_oracle (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct);
+fun oracle nbe_program_sym_tab ctxt vs_ty_t deps ct =
+ raw_oracle (nbe_program_sym_tab, ctxt, vs_ty_t, deps, ct);
fun dynamic_conv ctxt = lift_triv_classes_conv ctxt
(fn ctxt' => Code_Thingol.dynamic_conv ctxt' (fn program =>