--- a/src/Tools/nbe.ML Mon Apr 14 20:19:05 2025 +0200
+++ b/src/Tools/nbe.ML Mon Apr 14 20:19:05 2025 +0200
@@ -221,19 +221,20 @@
fun ml_cases e cs = enclose "(" ")"
("case " ^ e ^ " of " ^ space_implode " | " (map (fn (p, e) => p ^ " => " ^ e) cs));
-fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";
-fun ml_as v t = enclose "(" ")" (v ^ " as " ^ t);
+fun ml_let d e = "\n let\n" ^ prefix_lines " " d ^ "\n in " ^ e ^ " end";
fun ml_and [] = "true"
| ml_and [e] = e
| ml_and es = enclose "(" ")" (space_implode " andalso " es);
-fun ml_if b e1 e2 = enclose "(" ")" ("if" ^ b ^ " then " ^ e1 ^ " else " ^ e2);
+fun ml_if b e1 e2 = enclose "(" ")" (implode_space ["if", b, "then", e1, "else", e2]);
fun e1 `*` e2 = enclose "(" ")" (e1 ^ ", " ^ e2);
fun ml_list es = enclose "[" "]" (commas es);
+fun ml_exc s = enclose "(" ")" ("raise Fail " ^ quote s);
+
fun ml_fundefs ([(name, [([], e)])]) =
- "val " ^ name ^ " = " ^ e ^ "\n"
+ "val " ^ name ^ " = " ^ e ^ ""
| ml_fundefs (eqs :: eqss) =
let
fun fundef (name, eqs) =
@@ -243,7 +244,6 @@
in
(prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
|> cat_lines
- |> suffix "\n"
end;
@@ -278,17 +278,18 @@
fun nbe_tparam v = "t_" ^ v;
fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
+fun nbe_global_param v = "w_" ^ v;
fun nbe_bound v = "v_" ^ v;
fun nbe_bound_optional NONE = "_"
| nbe_bound_optional (SOME v) = nbe_bound v;
-fun nbe_default v = "w_" ^ v;
fun nbe_type a_sym a_tys = name_type `$` (quote a_sym `*` ml_list a_tys);
fun nbe_fun a_sym a_typargs = a_sym `$` ml_list a_typargs;
-(*note: these three are the "turning spots" where proper argument order is established!*)
+(*note: these are the "turning spots" where proper argument order is established!*)
fun nbe_apps a_u [] = a_u
| nbe_apps a_u a_us = name_apps `$$` [a_u, ml_list (rev a_us)];
fun nbe_apps_fun a_sym a_typargs a_us = nbe_fun a_sym a_typargs `$` ml_list (rev a_us);
+fun nbe_apps_seq_fun a_sym a_us = a_sym `$$` a_us;
fun nbe_apps_constr a_sym a_typargs a_us = name_const `$` ((a_sym `*` ml_list a_typargs) `*` ml_list (rev a_us));
fun nbe_apps_constmatch a_sym a_us = name_const `$` ((a_sym `*` "_") `*` ml_list (rev a_us));
@@ -332,24 +333,26 @@
val a_tparams = map (fn (v, _) => nbe_tparam v) vs;
val a_dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
val num_args = length a_dict_params + ((length o fst o hd) eqns);
- val a_default_params = map nbe_default (Name.invent_global "a" (num_args - length a_dict_params));
- in (sym, (num_args, (a_tparams, a_dict_params, (map o apfst) subst_nonlin_vars eqns, a_default_params))) end;
+ val a_global_params = map nbe_global_param (Name.invent_global "a" (num_args - length a_dict_params));
+ in (sym, (num_args, (a_tparams, a_dict_params, a_global_params, (map o apfst) subst_nonlin_vars eqns))) end;
fun assemble_type (tyco `%% tys) = nbe_type tyco (map assemble_type tys)
| assemble_type (ITyVar v) = nbe_tparam v
fun assemble_preprocessed_eqnss ctxt idx_of_sym deps eqnss =
let
- fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value"
- | fun_ident i sym = "c_" ^ string_of_int (idx_of_sym sym)
- ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
+ fun fun_ident sym = space_implode "_"
+ ["c", if Code_Symbol.is_value sym then "value" else string_of_int (idx_of_sym sym), Code_Symbol.default_base sym, "nbe"];
+ fun seq_fun_ident i sym = space_implode "_"
+ ["c", string_of_int (idx_of_sym sym), Code_Symbol.default_base sym, string_of_int i];
fun constr_ident sym =
if Config.get ctxt trace
then string_of_int (idx_of_sym sym) ^ " (*" ^ Code_Symbol.default_base sym ^ "*)"
else string_of_int (idx_of_sym sym);
- fun assemble_fun i sym = nbe_fun (fun_ident i sym);
- fun assemble_app_fun i sym = nbe_apps_fun (fun_ident i sym);
+ fun assemble_fun sym = nbe_fun (fun_ident sym);
+ fun assemble_app_fun sym = nbe_apps_fun (fun_ident sym);
+ fun assemble_app_seq_fun i sym = nbe_apps_seq_fun (seq_fun_ident i sym);
fun assemble_app_constr sym = nbe_apps_constr (constr_ident sym);
fun assemble_app_constmatch sym = nbe_apps_constmatch (constr_ident sym);
@@ -360,10 +363,10 @@
in case AList.lookup (op =) eqnss sym
of SOME (num_args, _) => if num_args <= length a_ts'
then let val (a_ts1, a_ts2) = chop num_args a_ts'
- in nbe_apps (assemble_app_fun 0 sym a_typargs a_ts1) a_ts2
- end else nbe_apps (nbe_abss num_args (assemble_fun 0 sym a_typargs)) a_ts'
+ in nbe_apps (assemble_app_fun sym a_typargs a_ts1) a_ts2
+ end else nbe_apps (nbe_abss num_args (assemble_fun sym a_typargs)) a_ts'
| NONE => if member (op =) deps sym
- then nbe_apps (assemble_fun 0 sym a_typargs) a_ts'
+ then nbe_apps (assemble_fun sym a_typargs) a_ts'
else assemble_app_constr sym a_typargs a_ts'
end
and assemble_classrels classrels =
@@ -397,35 +400,42 @@
val assemble_args = map (assemble_iterm assemble_constmatch NONE);
val assemble_rhs = assemble_iterm assemble_constapp;
- fun assemble_eqn sym a_tparams a_dict_params a_default_params (i, ((samepairs, args), rhs)) =
+ fun assemble_eqn sym a_global_params (i, ((samepairs, args), rhs)) =
let
- val default_rhs = assemble_app_fun (i + 1) sym a_tparams (a_dict_params @ a_default_params);
- val s_args = assemble_args args;
- val s_rhs = if null samepairs then assemble_rhs (SOME default_rhs) rhs
+ val a_fallback_rhs = assemble_app_seq_fun (i + 1) sym a_global_params;
+ val a_args = assemble_args args;
+ val a_rhs = if null samepairs then assemble_rhs (SOME a_fallback_rhs) rhs
else ml_if (ml_and (map nbe_same samepairs))
- (assemble_rhs (SOME default_rhs) rhs) default_rhs;
- val eqns = [([ml_list a_tparams, ml_list (rev (a_dict_params @ map2 ml_as a_default_params s_args))], s_rhs),
- ([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))], default_rhs)]
- in (fun_ident i sym, eqns) end;
+ (assemble_rhs (SOME a_fallback_rhs) rhs) a_fallback_rhs;
+ val fallback_eqn = if forall Code_Thingol.is_IVar args then []
+ else [(replicate (length a_global_params) "_", a_fallback_rhs)];
+ in
+ (seq_fun_ident i sym, (a_args, a_rhs) :: fallback_eqn)
+ end;
+
+ fun assemble_default_eqn sym a_tparams a_dict_params a_global_params i =
+ (seq_fun_ident i sym,
+ [(replicate (length a_global_params) "_", assemble_app_constr sym a_tparams (a_dict_params @ a_global_params))])
- fun assemble_default_eqn sym a_tparams a_dict_params a_default_params i =
- (fun_ident i sym,
- [([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))],
- assemble_app_constr sym a_tparams (a_dict_params @ a_default_params))])
-
- fun assemble_value_eqn sym a_tparams a_dict_params (([], args), rhs) =
- (fun_ident 0 sym,
- [([ml_list a_tparams, ml_list (rev (a_dict_params @ assemble_args args))], assemble_rhs NONE rhs)]);
+ fun assemble_seq_eqns sym a_tparams a_dict_params a_global_params [(([], []), rhs)] =
+ assemble_rhs NONE rhs
+ | assemble_seq_eqns sym a_tparams a_dict_params a_global_params eqns =
+ ml_let (ml_fundefs (map_index (assemble_eqn sym a_global_params) eqns
+ @ [assemble_default_eqn sym a_tparams a_dict_params a_global_params (length eqns)]))
+ (assemble_app_seq_fun 0 sym a_global_params);
- fun assemble_eqns (sym, (num_args, (a_tparams, a_dict_params, eqns, default_params))) =
- (if Code_Symbol.is_value sym then [assemble_value_eqn sym a_tparams a_dict_params (the_single eqns)]
- else map_index (assemble_eqn sym a_tparams a_dict_params default_params) eqns
- @ [assemble_default_eqn sym a_tparams a_dict_params default_params (length eqns)],
- ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun 0 sym a_tparams)));
+ fun assemble_eqns (sym, (num_args, (a_tparams, a_dict_params, a_global_params, eqns))) =
+ let
+ val a_lhs = [ml_list a_tparams, ml_list (rev (a_dict_params @ a_global_params))];
+ val a_rhs = assemble_seq_eqns sym a_tparams a_dict_params a_global_params eqns;
+ val a_univ = ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun sym a_tparams));
+ in
+ ((fun_ident sym, [(a_lhs, a_rhs), (a_lhs, ml_exc (fun_ident sym))]), a_univ)
+ end;
- val (fun_vars, fun_vals) = map_split assemble_eqns eqnss;
- val deps_vars = ml_list (map (fun_ident 0) deps);
- in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
+ val (a_fun_defs, a_fun_vals) = map_split assemble_eqns eqnss;
+ val dep_params = ml_list (map fun_ident deps);
+ in ml_abs dep_params (ml_let (ml_fundefs a_fun_defs) (ml_list a_fun_vals)) end;
fun assemble_eqnss ctxt idx_of_sym deps eqnss =
assemble_preprocessed_eqnss ctxt idx_of_sym deps (map preprocess_eqns eqnss);