revamped generation of functions
authorhaftmann
Mon, 14 Apr 2025 20:19:05 +0200
changeset 82506 289b18955960
parent 82505 fa641833c0ff
child 82507 38550f21275d
revamped generation of functions
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Mon Apr 14 20:19:05 2025 +0200
+++ b/src/Tools/nbe.ML	Mon Apr 14 20:19:05 2025 +0200
@@ -221,19 +221,20 @@
 
 fun ml_cases e cs = enclose "(" ")"
   ("case " ^ e ^ " of " ^ space_implode " | " (map (fn (p, e) => p ^ " => " ^ e) cs));
-fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";
-fun ml_as v t = enclose "(" ")" (v ^ " as " ^ t);
+fun ml_let d e = "\n  let\n" ^ prefix_lines "    " d ^ "\n  in " ^ e ^ " end";
 
 fun ml_and [] = "true"
   | ml_and [e] = e
   | ml_and es = enclose "(" ")" (space_implode " andalso " es);
-fun ml_if b e1 e2 = enclose "(" ")" ("if" ^ b ^ " then " ^ e1 ^ " else " ^ e2);
+fun ml_if b e1 e2 = enclose "(" ")" (implode_space ["if", b, "then", e1, "else", e2]);
 
 fun e1 `*` e2 = enclose "(" ")" (e1 ^ ", " ^ e2);
 fun ml_list es = enclose "[" "]" (commas es);
 
+fun ml_exc s = enclose "(" ")" ("raise Fail " ^ quote s);
+
 fun ml_fundefs ([(name, [([], e)])]) =
-      "val " ^ name ^ " = " ^ e ^ "\n"
+      "val " ^ name ^ " = " ^ e ^ ""
   | ml_fundefs (eqs :: eqss) =
       let
         fun fundef (name, eqs) =
@@ -243,7 +244,6 @@
       in
         (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
         |> cat_lines
-        |> suffix "\n"
       end;
 
 
@@ -278,17 +278,18 @@
 
 fun nbe_tparam v = "t_" ^ v;
 fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
+fun nbe_global_param v = "w_" ^ v;
 fun nbe_bound v = "v_" ^ v;
 fun nbe_bound_optional NONE = "_"
   | nbe_bound_optional (SOME v) = nbe_bound v;
-fun nbe_default v = "w_" ^ v;
 fun nbe_type a_sym a_tys = name_type `$` (quote a_sym `*` ml_list a_tys);
 fun nbe_fun a_sym a_typargs = a_sym `$` ml_list a_typargs;
 
-(*note: these three are the "turning spots" where proper argument order is established!*)
+(*note: these are the "turning spots" where proper argument order is established!*)
 fun nbe_apps a_u [] = a_u
   | nbe_apps a_u a_us = name_apps `$$` [a_u, ml_list (rev a_us)];
 fun nbe_apps_fun a_sym a_typargs a_us = nbe_fun a_sym a_typargs `$` ml_list (rev a_us);
+fun nbe_apps_seq_fun a_sym a_us = a_sym `$$` a_us;
 fun nbe_apps_constr a_sym a_typargs a_us = name_const `$` ((a_sym `*` ml_list a_typargs) `*` ml_list (rev a_us));
 fun nbe_apps_constmatch a_sym a_us = name_const `$` ((a_sym `*` "_") `*` ml_list (rev a_us));
 
@@ -332,24 +333,26 @@
     val a_tparams = map (fn (v, _) => nbe_tparam v) vs;
     val a_dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
     val num_args = length a_dict_params + ((length o fst o hd) eqns);
-    val a_default_params = map nbe_default (Name.invent_global "a" (num_args - length a_dict_params));
-  in (sym, (num_args, (a_tparams, a_dict_params, (map o apfst) subst_nonlin_vars eqns, a_default_params))) end;
+    val a_global_params = map nbe_global_param (Name.invent_global "a" (num_args - length a_dict_params));
+  in (sym, (num_args, (a_tparams, a_dict_params, a_global_params, (map o apfst) subst_nonlin_vars eqns))) end;
 
 fun assemble_type (tyco `%% tys) = nbe_type tyco (map assemble_type tys)
   | assemble_type (ITyVar v) = nbe_tparam v
 
 fun assemble_preprocessed_eqnss ctxt idx_of_sym deps eqnss =
   let
-    fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value" 
-      | fun_ident i sym = "c_" ^ string_of_int (idx_of_sym sym)
-          ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
+    fun fun_ident sym = space_implode "_"
+      ["c", if Code_Symbol.is_value sym then "value" else string_of_int (idx_of_sym sym), Code_Symbol.default_base sym, "nbe"];
+    fun seq_fun_ident i sym = space_implode "_"
+      ["c", string_of_int (idx_of_sym sym), Code_Symbol.default_base sym, string_of_int i];
     fun constr_ident sym =
       if Config.get ctxt trace
       then string_of_int (idx_of_sym sym) ^ " (*" ^ Code_Symbol.default_base sym ^ "*)"
       else string_of_int (idx_of_sym sym);
 
-    fun assemble_fun i sym = nbe_fun (fun_ident i sym);
-    fun assemble_app_fun i sym = nbe_apps_fun (fun_ident i sym);
+    fun assemble_fun sym = nbe_fun (fun_ident sym);
+    fun assemble_app_fun sym = nbe_apps_fun (fun_ident sym);
+    fun assemble_app_seq_fun i sym = nbe_apps_seq_fun (seq_fun_ident i sym);
     fun assemble_app_constr sym = nbe_apps_constr (constr_ident sym);
     fun assemble_app_constmatch sym = nbe_apps_constmatch (constr_ident sym);
 
@@ -360,10 +363,10 @@
       in case AList.lookup (op =) eqnss sym
        of SOME (num_args, _) => if num_args <= length a_ts'
             then let val (a_ts1, a_ts2) = chop num_args a_ts'
-            in nbe_apps (assemble_app_fun 0 sym a_typargs a_ts1) a_ts2
-            end else nbe_apps (nbe_abss num_args (assemble_fun 0 sym a_typargs)) a_ts'
+            in nbe_apps (assemble_app_fun sym a_typargs a_ts1) a_ts2
+            end else nbe_apps (nbe_abss num_args (assemble_fun sym a_typargs)) a_ts'
         | NONE => if member (op =) deps sym
-            then nbe_apps (assemble_fun 0 sym a_typargs) a_ts'
+            then nbe_apps (assemble_fun sym a_typargs) a_ts'
             else assemble_app_constr sym a_typargs a_ts'
       end
     and assemble_classrels classrels =
@@ -397,35 +400,42 @@
     val assemble_args = map (assemble_iterm assemble_constmatch NONE);
     val assemble_rhs = assemble_iterm assemble_constapp;
 
-    fun assemble_eqn sym a_tparams a_dict_params a_default_params (i, ((samepairs, args), rhs)) =
+    fun assemble_eqn sym a_global_params (i, ((samepairs, args), rhs)) =
       let
-        val default_rhs = assemble_app_fun (i + 1) sym a_tparams (a_dict_params @ a_default_params);
-        val s_args = assemble_args args;
-        val s_rhs = if null samepairs then assemble_rhs (SOME default_rhs) rhs
+        val a_fallback_rhs = assemble_app_seq_fun (i + 1) sym a_global_params;
+        val a_args = assemble_args args;
+        val a_rhs = if null samepairs then assemble_rhs (SOME a_fallback_rhs) rhs
           else ml_if (ml_and (map nbe_same samepairs))
-            (assemble_rhs (SOME default_rhs) rhs) default_rhs;
-        val eqns = [([ml_list a_tparams, ml_list (rev (a_dict_params @ map2 ml_as a_default_params s_args))], s_rhs),
-          ([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))], default_rhs)]
-      in (fun_ident i sym, eqns) end;
+            (assemble_rhs (SOME a_fallback_rhs) rhs) a_fallback_rhs;
+        val fallback_eqn = if forall Code_Thingol.is_IVar args then []
+          else [(replicate (length a_global_params) "_", a_fallback_rhs)];
+      in
+        (seq_fun_ident i sym, (a_args, a_rhs) :: fallback_eqn)
+      end;
+
+    fun assemble_default_eqn sym a_tparams a_dict_params a_global_params i =
+      (seq_fun_ident i sym,
+        [(replicate (length a_global_params) "_", assemble_app_constr sym a_tparams (a_dict_params @ a_global_params))])
 
-    fun assemble_default_eqn sym a_tparams a_dict_params a_default_params i =
-      (fun_ident i sym,
-        [([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))],
-          assemble_app_constr sym a_tparams (a_dict_params @ a_default_params))])
-
-    fun assemble_value_eqn sym a_tparams a_dict_params (([], args), rhs) =
-      (fun_ident 0 sym,
-        [([ml_list a_tparams, ml_list (rev (a_dict_params @ assemble_args args))], assemble_rhs NONE rhs)]);
+    fun assemble_seq_eqns sym a_tparams a_dict_params a_global_params [(([], []), rhs)] =
+          assemble_rhs NONE rhs
+      | assemble_seq_eqns sym a_tparams a_dict_params a_global_params eqns =
+          ml_let (ml_fundefs (map_index (assemble_eqn sym a_global_params) eqns
+            @ [assemble_default_eqn sym a_tparams a_dict_params a_global_params (length eqns)]))
+          (assemble_app_seq_fun 0 sym a_global_params);
 
-    fun assemble_eqns (sym, (num_args, (a_tparams, a_dict_params, eqns, default_params))) =
-      (if Code_Symbol.is_value sym then [assemble_value_eqn sym a_tparams a_dict_params (the_single eqns)]
-      else map_index (assemble_eqn sym a_tparams a_dict_params default_params) eqns
-        @ [assemble_default_eqn sym a_tparams a_dict_params default_params (length eqns)],
-      ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun 0 sym a_tparams)));
+    fun assemble_eqns (sym, (num_args, (a_tparams, a_dict_params, a_global_params, eqns))) =
+      let
+        val a_lhs = [ml_list a_tparams, ml_list (rev (a_dict_params @ a_global_params))];
+        val a_rhs = assemble_seq_eqns sym a_tparams a_dict_params a_global_params eqns;
+        val a_univ = ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun sym a_tparams));
+      in
+        ((fun_ident sym, [(a_lhs, a_rhs), (a_lhs, ml_exc (fun_ident sym))]), a_univ)
+      end;
 
-    val (fun_vars, fun_vals) = map_split assemble_eqns eqnss;
-    val deps_vars = ml_list (map (fun_ident 0) deps);
-  in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
+    val (a_fun_defs, a_fun_vals) = map_split assemble_eqns eqnss;
+    val dep_params = ml_list (map fun_ident deps);
+  in ml_abs dep_params (ml_let (ml_fundefs a_fun_defs) (ml_list a_fun_vals)) end;
 
 fun assemble_eqnss ctxt idx_of_sym deps eqnss =
   assemble_preprocessed_eqnss ctxt idx_of_sym deps (map preprocess_eqns eqnss);