--- a/src/Tools/nbe.ML Mon Apr 14 20:04:40 2025 +0200
+++ b/src/Tools/nbe.ML Mon Apr 14 20:42:03 2025 +0200
@@ -214,6 +214,7 @@
infix 9 `$` `$$`;
infix 8 `*`;
+
fun e1 `$` e2 = enclose "(" ")" (e1 ^ " " ^ e2);
fun e `$$` [] = e
| e `$$` es = enclose "(" ")" (e ^ " " ^ implode_space es);
@@ -221,19 +222,20 @@
fun ml_cases e cs = enclose "(" ")"
("case " ^ e ^ " of " ^ space_implode " | " (map (fn (p, e) => p ^ " => " ^ e) cs));
-fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";
-fun ml_as v t = enclose "(" ")" (v ^ " as " ^ t);
+fun ml_let d e = "\n let\n" ^ prefix_lines " " d ^ "\n in " ^ e ^ " end";
fun ml_and [] = "true"
| ml_and [e] = e
| ml_and es = enclose "(" ")" (space_implode " andalso " es);
-fun ml_if b e1 e2 = enclose "(" ")" ("if" ^ b ^ " then " ^ e1 ^ " else " ^ e2);
+fun ml_if b e1 e2 = enclose "(" ")" (implode_space ["if", b, "then", e1, "else", e2]);
fun e1 `*` e2 = enclose "(" ")" (e1 ^ ", " ^ e2);
fun ml_list es = enclose "[" "]" (commas es);
+fun ml_exc s = enclose "(" ")" ("raise Fail " ^ quote s);
+
fun ml_fundefs ([(name, [([], e)])]) =
- "val " ^ name ^ " = " ^ e ^ "\n"
+ "val " ^ name ^ " = " ^ e ^ ""
| ml_fundefs (eqs :: eqss) =
let
fun fundef (name, eqs) =
@@ -243,7 +245,6 @@
in
(prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
|> cat_lines
- |> suffix "\n"
end;
@@ -278,19 +279,20 @@
fun nbe_tparam v = "t_" ^ v;
fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
+fun nbe_global_param v = "w_" ^ v;
fun nbe_bound v = "v_" ^ v;
fun nbe_bound_optional NONE = "_"
| nbe_bound_optional (SOME v) = nbe_bound v;
-fun nbe_default v = "w_" ^ v;
fun nbe_type a_sym a_tys = name_type `$` (quote a_sym `*` ml_list a_tys);
fun nbe_fun a_sym a_typargs = a_sym `$` ml_list a_typargs;
-(*note: these three are the "turning spots" where proper argument order is established!*)
+(*note: these are the "turning spots" where proper argument order is established!*)
fun nbe_apps a_u [] = a_u
| nbe_apps a_u a_us = name_apps `$$` [a_u, ml_list (rev a_us)];
fun nbe_apps_fun a_sym a_typargs a_us = nbe_fun a_sym a_typargs `$` ml_list (rev a_us);
-fun nbe_apps_constr a_sym a_typargs a_us = name_const `$` ((a_sym `*` ml_list a_typargs) `*` ml_list (rev a_us));
-fun nbe_apps_constmatch a_sym a_us = name_const `$` ((a_sym `*` "_") `*` ml_list (rev a_us));
+fun nbe_apps_fallback_fun a_sym a_us = a_sym `$$` a_us;
+fun nbe_apps_const a_sym a_typargs a_us = name_const `$` ((a_sym `*` ml_list a_typargs) `*` ml_list (rev a_us));
+fun nbe_apps_constpat a_sym a_us = name_const `$` ((a_sym `*` "_") `*` ml_list (rev a_us));
fun nbe_abss 0 f = f `$` ml_list []
| nbe_abss n f = name_abss `$$` [string_of_int n, f];
@@ -332,26 +334,29 @@
val a_tparams = map (fn (v, _) => nbe_tparam v) vs;
val a_dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
val num_args = length a_dict_params + ((length o fst o hd) eqns);
- val a_default_params = map nbe_default (Name.invent_global "a" (num_args - length a_dict_params));
- in (sym, (num_args, (a_tparams, a_dict_params, (map o apfst) subst_nonlin_vars eqns, a_default_params))) end;
+ val a_global_params = map nbe_global_param (Name.invent_global "a" (num_args - length a_dict_params));
+ in (sym, (num_args, (a_tparams, a_dict_params, a_global_params, (map o apfst) subst_nonlin_vars eqns))) end;
fun assemble_type (tyco `%% tys) = nbe_type tyco (map assemble_type tys)
| assemble_type (ITyVar v) = nbe_tparam v
fun assemble_preprocessed_eqnss ctxt idx_of_sym deps eqnss =
let
- fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value"
- | fun_ident i sym = "c_" ^ string_of_int (idx_of_sym sym)
- ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
- fun constr_ident sym =
+ fun suffixed_fun_ident suffix sym = space_implode "_"
+ ["c", if Code_Symbol.is_value sym then "0" else string_of_int (idx_of_sym sym),
+ Code_Symbol.default_base sym, suffix];
+ val fun_ident = suffixed_fun_ident "nbe";
+ fun fallback_fun_ident i = suffixed_fun_ident (string_of_int i);
+ fun const_ident sym =
if Config.get ctxt trace
then string_of_int (idx_of_sym sym) ^ " (*" ^ Code_Symbol.default_base sym ^ "*)"
else string_of_int (idx_of_sym sym);
- fun assemble_fun i sym = nbe_fun (fun_ident i sym);
- fun assemble_app_fun i sym = nbe_apps_fun (fun_ident i sym);
- fun assemble_app_constr sym = nbe_apps_constr (constr_ident sym);
- fun assemble_app_constmatch sym = nbe_apps_constmatch (constr_ident sym);
+ fun assemble_fun sym = nbe_fun (fun_ident sym);
+ fun assemble_app_fun sym = nbe_apps_fun (fun_ident sym);
+ fun assemble_app_fallback_fun i sym = nbe_apps_fallback_fun (fallback_fun_ident i sym);
+ fun assemble_app_const sym = nbe_apps_const (const_ident sym);
+ fun assemble_app_constpat sym = nbe_apps_constpat (const_ident sym);
fun assemble_constapp sym typargs dictss a_ts =
let
@@ -360,11 +365,11 @@
in case AList.lookup (op =) eqnss sym
of SOME (num_args, _) => if num_args <= length a_ts'
then let val (a_ts1, a_ts2) = chop num_args a_ts'
- in nbe_apps (assemble_app_fun 0 sym a_typargs a_ts1) a_ts2
- end else nbe_apps (nbe_abss num_args (assemble_fun 0 sym a_typargs)) a_ts'
+ in nbe_apps (assemble_app_fun sym a_typargs a_ts1) a_ts2
+ end else nbe_apps (nbe_abss num_args (assemble_fun sym a_typargs)) a_ts'
| NONE => if member (op =) deps sym
- then nbe_apps (assemble_fun 0 sym a_typargs) a_ts'
- else assemble_app_constr sym a_typargs a_ts'
+ then nbe_apps (assemble_fun sym a_typargs) a_ts'
+ else assemble_app_const sym a_typargs a_ts'
end
and assemble_classrels classrels =
fold_rev (fn classrel => assemble_constapp (Class_Relation classrel) [] [] o single) classrels
@@ -373,59 +378,59 @@
and assemble_plain_dict (_ `%% tys) (Dict_Const (inst, dictss)) =
assemble_constapp (Class_Instance inst) tys (map snd dictss) []
| assemble_plain_dict _ (Dict_Var { var, index, ... }) =
- nbe_dict var index
+ nbe_dict var index;
- fun assemble_constmatch sym _ dictss a_ts =
- assemble_app_constmatch sym ((maps o map) (K "_") dictss @ a_ts);
-
- fun assemble_iterm constapp =
+ fun assemble_iterm is_pat a_match_fallback t =
let
- fun of_iterm match_continuation t =
- let
- val (t', ts) = Code_Thingol.unfold_app t
- in of_iapp match_continuation t' (fold_rev (cons o of_iterm NONE) ts []) end
- and of_iapp match_continuation (IConst { sym, typargs, dictss, ... }) = constapp sym typargs dictss
- | of_iapp match_continuation (IVar v) = nbe_apps (nbe_bound_optional v)
- | of_iapp match_continuation ((v, _) `|=> (t, _)) =
- nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t)))
- | of_iapp match_continuation (ICase { term = t, clauses = clauses, primitive = t0, ... }) =
- nbe_apps (ml_cases (of_iterm NONE t)
- (map (fn (p, t) => (assemble_iterm assemble_constmatch NONE p, of_iterm match_continuation t)) clauses
- @ [("_", case match_continuation of SOME s => s | NONE => of_iterm NONE t0)]))
- in of_iterm end;
+ fun assemble_app (IConst { sym, typargs, dictss, ... }) =
+ if is_pat then fn a_ts => assemble_app_constpat sym ((maps o map) (K "_") dictss @ a_ts)
+ else assemble_constapp sym typargs dictss
+ | assemble_app (IVar v) = nbe_apps (nbe_bound_optional v)
+ | assemble_app ((v, _) `|=> (t, _)) =
+ nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (assemble_iterm is_pat NONE t)))
+ | assemble_app (ICase { term = t, clauses = clauses, primitive = t0, ... }) =
+ nbe_apps (ml_cases (assemble_iterm is_pat NONE t)
+ (map (fn (p, t) => (assemble_iterm true NONE p, assemble_iterm is_pat a_match_fallback t)) clauses
+ @ [("_", case a_match_fallback of SOME s => s | NONE => assemble_iterm is_pat NONE t0)]))
+ val (t', ts) = Code_Thingol.unfold_app t;
+ val a_ts = fold_rev (cons o assemble_iterm is_pat NONE) ts [];
+ in assemble_app t' a_ts end;
- val assemble_args = map (assemble_iterm assemble_constmatch NONE);
- val assemble_rhs = assemble_iterm assemble_constapp;
-
- fun assemble_eqn sym a_tparams a_dict_params a_default_params (i, ((samepairs, args), rhs)) =
+ fun assemble_fallback_fundef sym a_global_params ((samepairs, args), rhs) a_fallback_rhs =
let
- val default_rhs = assemble_app_fun (i + 1) sym a_tparams (a_dict_params @ a_default_params);
- val s_args = assemble_args args;
- val s_rhs = if null samepairs then assemble_rhs (SOME default_rhs) rhs
- else ml_if (ml_and (map nbe_same samepairs))
- (assemble_rhs (SOME default_rhs) rhs) default_rhs;
- val eqns = [([ml_list a_tparams, ml_list (rev (a_dict_params @ map2 ml_as a_default_params s_args))], s_rhs),
- ([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))], default_rhs)]
- in (fun_ident i sym, eqns) end;
-
- fun assemble_default_eqn sym a_tparams a_dict_params a_default_params i =
- (fun_ident i sym,
- [([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))],
- assemble_app_constr sym a_tparams (a_dict_params @ a_default_params))])
+ val a_rhs_core = assemble_iterm false (SOME a_fallback_rhs) rhs;
+ val a_rhs = if null samepairs then a_rhs_core
+ else ml_if (ml_and (map nbe_same samepairs)) a_rhs_core a_fallback_rhs;
+ val a_fallback_eqn = if forall Code_Thingol.is_IVar args then NONE
+ else SOME (replicate (length a_global_params) "_", a_fallback_rhs);
+ in (map (assemble_iterm true NONE) args, a_rhs) :: the_list a_fallback_eqn end;
- fun assemble_value_eqn sym a_tparams a_dict_params (([], args), rhs) =
- (fun_ident 0 sym,
- [([ml_list a_tparams, ml_list (rev (a_dict_params @ assemble_args args))], assemble_rhs NONE rhs)]);
+ fun assemble_fallback_fundefs sym a_tparams a_dict_params a_global_params [(([], []), rhs)] =
+ assemble_iterm false NONE rhs
+ | assemble_fallback_fundefs sym a_tparams a_dict_params a_global_params eqns =
+ let
+ val a_fallback_syms = map_range (fn i => fallback_fun_ident i sym) (length eqns);
+ val a_fallback_rhss =
+ map_range (fn i => assemble_app_fallback_fun (i + 1) sym a_global_params) (length eqns - 1)
+ @ [assemble_app_const sym a_tparams (a_dict_params @ a_global_params)];
+ in
+ ml_let (ml_fundefs (a_fallback_syms ~~
+ map2 (assemble_fallback_fundef sym a_global_params) eqns a_fallback_rhss))
+ (assemble_app_fallback_fun 0 sym a_global_params)
+ end;
- fun assemble_eqns (sym, (num_args, (a_tparams, a_dict_params, eqns, default_params))) =
- (if Code_Symbol.is_value sym then [assemble_value_eqn sym a_tparams a_dict_params (the_single eqns)]
- else map_index (assemble_eqn sym a_tparams a_dict_params default_params) eqns
- @ [assemble_default_eqn sym a_tparams a_dict_params default_params (length eqns)],
- ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun 0 sym a_tparams)));
+ fun assemble_fundef (sym, (num_args, (a_tparams, a_dict_params, a_global_params, eqns))) =
+ let
+ val a_lhs = [ml_list a_tparams, ml_list (rev (a_dict_params @ a_global_params))];
+ val a_rhs = assemble_fallback_fundefs sym a_tparams a_dict_params a_global_params eqns;
+ val a_univ = ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun sym a_tparams));
+ in
+ ((fun_ident sym, [(a_lhs, a_rhs), (a_lhs, ml_exc (fun_ident sym))]), a_univ)
+ end;
- val (fun_vars, fun_vals) = map_split assemble_eqns eqnss;
- val deps_vars = ml_list (map (fun_ident 0) deps);
- in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
+ val (a_fun_defs, a_fun_vals) = map_split assemble_fundef eqnss;
+ val dep_params = ml_list (map fun_ident deps);
+ in ml_abs dep_params (ml_let (ml_fundefs a_fun_defs) (ml_list a_fun_vals)) end;
fun assemble_eqnss ctxt idx_of_sym deps eqnss =
assemble_preprocessed_eqnss ctxt idx_of_sym deps (map preprocess_eqns eqnss);