tuned
authorhaftmann
Fri, 28 Mar 2025 14:13:37 +0100
changeset 82374 2d0721461810
parent 82373 2819f79792b9
child 82375 1972ae7da0d2
tuned
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Fri Mar 28 14:13:36 2025 +0100
+++ b/src/Tools/nbe.ML	Fri Mar 28 14:13:37 2025 +0100
@@ -255,9 +255,6 @@
 
 val univs_cookie = (get_result, put_result, name_put);
 
-fun nbe_fun idx_of 0 (Code_Symbol.Constant "") = "nbe_value"
-  | nbe_fun idx_of i sym = "c_" ^ string_of_int (idx_of sym)
-      ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
 fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
 fun nbe_bound v = "v_" ^ v;
 fun nbe_bound_optional NONE = "_"
@@ -267,13 +264,8 @@
 (*note: these three are the "turning spots" where proper argument order is established!*)
 fun nbe_apps t [] = t
   | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)];
-fun nbe_apps_local idx_of i c ts = nbe_fun idx_of i c `$` ml_list (rev ts);
-fun nbe_apps_constr ctxt idx_of c ts =
-  let
-    val c' = if Config.get ctxt trace
-      then string_of_int (idx_of c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)"
-      else string_of_int (idx_of c);
-  in name_const `$` ("(" ^ c' ^ ", " ^ ml_list (rev ts) ^ ")") end;
+fun nbe_apps_local c ts = c `$` ml_list (rev ts);
+fun nbe_apps_constr c ts = name_const `$` ("(" ^ c ^ ", " ^ ml_list (rev ts) ^ ")");
 
 fun nbe_abss 0 f = f `$` ml_list []
   | nbe_abss n f = name_abss `$$` [string_of_int n, f];
@@ -288,26 +280,37 @@
 
 (* code generation *)
 
-fun assemble_eqnss ctxt idx_of deps eqnss =
+fun assemble_eqnss ctxt idx_of_const deps eqnss =
   let
-    fun prep_eqns (c, (vs, eqns)) =
+    fun prep_eqns (sym, (vs, eqns)) =
       let
         val dicts = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
         val num_args = length dicts + ((length o fst o hd) eqns);
-      in (c, (num_args, (dicts, eqns))) end;
+      in (sym, (num_args, (dicts, eqns))) end;
     val eqnss' = map prep_eqns eqnss;
 
+    fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value"
+      | fun_ident i sym = "c_" ^ string_of_int (idx_of_const sym)
+          ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
+    fun constr_fun_ident c =
+      if Config.get ctxt trace
+      then string_of_int (idx_of_const c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)"
+      else string_of_int (idx_of_const c);
+
+    fun apply_local i sym = nbe_apps_local (fun_ident i sym);
+    fun apply_constr sym = nbe_apps_constr (constr_fun_ident sym);
+
     fun assemble_constapp sym dss ts = 
       let
         val ts' = (maps o map) assemble_dict dss @ ts;
       in case AList.lookup (op =) eqnss' sym
        of SOME (num_args, _) => if num_args <= length ts'
             then let val (ts1, ts2) = chop num_args ts'
-            in nbe_apps (nbe_apps_local idx_of 0 sym ts1) ts2
-            end else nbe_apps (nbe_abss num_args (nbe_fun idx_of 0 sym)) ts'
+            in nbe_apps (apply_local 0 sym ts1) ts2
+            end else nbe_apps (nbe_abss num_args (fun_ident 0 sym)) ts'
         | NONE => if member (op =) deps sym
-            then nbe_apps (nbe_fun idx_of 0 sym) ts'
-            else nbe_apps_constr ctxt idx_of sym ts'
+            then nbe_apps (fun_ident 0 sym) ts'
+            else apply_constr sym ts'
       end
     and assemble_classrels classrels =
       fold_rev (fn classrel => assemble_constapp (Class_Relation classrel) [] o single) classrels
@@ -359,9 +362,9 @@
     fun assemble_eqn sym dicts default_args (i, (args, rhs)) =
       let
         val match_cont = if Code_Symbol.is_value sym then NONE
-          else SOME (nbe_apps_local idx_of (i + 1) sym (dicts @ default_args));
+          else SOME (apply_local (i + 1) sym (dicts @ default_args));
         val assemble_arg = assemble_iterm
-          (fn sym' => fn dss => fn ts => nbe_apps_constr ctxt idx_of sym' ((maps o map) (K "_")
+          (fn sym' => fn dss => fn ts => apply_constr sym' ((maps o map) (K "_")
             dss @ ts)) NONE;
         val assemble_rhs = assemble_iterm assemble_constapp match_cont;
         val (samepairs, args') = subst_nonlin_vars args;
@@ -374,19 +377,19 @@
           | SOME default_rhs =>
               [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
                 ([ml_list (rev (dicts @ default_args))], default_rhs)]
-      in (nbe_fun idx_of i sym, eqns) end;
+      in (fun_ident i sym, eqns) end;
 
     fun assemble_eqns (sym, (num_args, (dicts, eqns))) =
       let
         val default_args = map nbe_default (Name.invent_global "a" (num_args - length dicts));
         val eqns' = map_index (assemble_eqn sym dicts default_args) eqns
-          @ (if Code_Symbol.is_value sym then [] else [(nbe_fun idx_of (length eqns) sym,
+          @ (if Code_Symbol.is_value sym then [] else [(fun_ident (length eqns) sym,
             [([ml_list (rev (dicts @ default_args))],
-              nbe_apps_constr ctxt idx_of sym (dicts @ default_args))])]);
-      in (eqns', nbe_abss num_args (nbe_fun idx_of 0 sym)) end;
+              apply_constr sym (dicts @ default_args))])]);
+      in (eqns', nbe_abss num_args (fun_ident 0 sym)) end;
 
     val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
-    val deps_vars = ml_list (map (nbe_fun idx_of 0) deps);
+    val deps_vars = ml_list (map (fun_ident 0) deps);
   in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
 
 
@@ -397,19 +400,19 @@
       let
         val (deps, deps_vals) = split_list (map_filter
           (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps);
-        val idx_of = raw_deps
+        val idx_of_const = raw_deps
           |> map (fn dep => (dep, snd (Code_Symbol.Graph.get_node nbe_program dep)))
           |> AList.lookup (op =)
           |> (fn f => the o f);
-        val s = assemble_eqnss ctxt idx_of deps eqnss;
-        val cs = map fst eqnss;
+        val s = assemble_eqnss ctxt idx_of_const deps eqnss;
+        val syms = map fst eqnss;
       in
         s
         |> traced ctxt (fn s => "\n--- code to be evaluated:\n" ^ s)
         |> pair ""
         |> Code_Runtime.value ctxt univs_cookie
         |> (fn f => f deps_vals)
-        |> (fn univs => cs ~~ univs)
+        |> (fn univs => syms ~~ univs)
       end;
 
 
@@ -450,11 +453,11 @@
 
 (* compilation of whole programs *)
 
-fun ensure_const_idx name (nbe_program, (maxidx, idx_tab)) =
+fun ensure_const_idx name (nbe_program, (maxidx, const_tab)) =
   if can (Code_Symbol.Graph.get_node nbe_program) name
-  then (nbe_program, (maxidx, idx_tab))
+  then (nbe_program, (maxidx, const_tab))
   else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program,
-    (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
+    (maxidx + 1, Inttab.update_new (maxidx, name) const_tab));
 
 fun compile_stmts ctxt stmts_deps =
   let
@@ -472,16 +475,16 @@
     fold ensure_const_idx refl_deps
     #> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps
       #> compile
-      #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ))))
+      #-> fold (fn (sym, univ) => (Code_Symbol.Graph.map_node sym o apfst) (K (SOME univ))))
   end;
 
 fun compile_program { ctxt, program } =
   let
-    fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
-      then (nbe_program, (maxidx, idx_tab))
-      else (nbe_program, (maxidx, idx_tab))
-        |> compile_stmts ctxt (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
-          Code_Symbol.Graph.immediate_succs program name)) names);
+    fun add_stmts names (nbe_program, (maxidx, const_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
+      then (nbe_program, (maxidx, const_tab))
+      else (nbe_program, (maxidx, const_tab))
+        |> compile_stmts ctxt (map (fn sym => ((sym, Code_Symbol.Graph.get_node program sym),
+          Code_Symbol.Graph.immediate_succs program sym)) names);
   in
     fold_rev add_stmts (Code_Symbol.Graph.strong_conn program)
   end;
@@ -501,18 +504,18 @@
     |> (fn t => apps t (rev dict_frees))
   end;
 
-fun reconstruct_term ctxt (idx_tab : Code_Symbol.T Inttab.table) t =
+fun reconstruct_term ctxt (const_tab : Code_Symbol.T Inttab.table) t =
   let
     fun take_until f [] = []
       | take_until f (x :: xs) = if f x then [] else x :: take_until f xs;
     fun is_dict (Const (idx, _)) =
-          (case Inttab.lookup idx_tab idx of
+          (case Inttab.lookup const_tab idx of
             SOME (Constant _) => false
           | _ => true)
       | is_dict (DFree _) = true
       | is_dict _ = false;
     fun const_of_idx idx =
-      case Inttab.lookup idx_tab idx of SOME (Constant const) => const;
+      case Inttab.lookup const_tab idx of SOME (Constant const) => const;
     fun of_apps bounds (t, ts) =
       fold_map (of_univ bounds) ts
       #>> (fn ts' => list_comb (t, rev ts'))
@@ -533,12 +536,12 @@
           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   in of_univ 0 t 0 |> fst end;
 
-fun compile_and_reconstruct_term { ctxt, nbe_program, idx_tab, deps, term } =
+fun compile_and_reconstruct_term { ctxt, nbe_program, const_tab, deps, term } =
   compile_term
     { ctxt = ctxt, nbe_program = nbe_program, deps = deps, term = term }
-  |> reconstruct_term ctxt idx_tab;
+  |> reconstruct_term ctxt const_tab;
 
-fun normalize_term (nbe_program, idx_tab) raw_ctxt t_original ((vs, ty) : typscheme, t) deps =
+fun normalize_term (nbe_program, const_tab) raw_ctxt t_original ((vs, ty) : typscheme, t) deps =
   let
     val ctxt = Syntax.init_pretty_global (Proof_Context.theory_of raw_ctxt);
     val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
@@ -553,7 +556,7 @@
       else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t');
   in
     Code_Preproc.timed "computing NBE expression" #ctxt compile_and_reconstruct_term
-      { ctxt = ctxt, nbe_program = nbe_program, idx_tab = idx_tab, deps = deps, term = (vs, t) }
+      { ctxt = ctxt, nbe_program = nbe_program, const_tab = const_tab, deps = deps, term = (vs, t) }
     |> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t)
     |> type_infer
     |> traced ctxt (fn t => "Types inferred:\n" ^ string_of_term t)
@@ -572,11 +575,11 @@
 
 fun compile ignore_cache ctxt program =
   let
-    val (nbe_program, (_, idx_tab)) =
+    val (nbe_program, (_, const_tab)) =
       Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt))
         (Code_Preproc.timed "compiling NBE program" #ctxt
           compile_program { ctxt = ctxt, program = program });
-  in (nbe_program, idx_tab) end;
+  in (nbe_program, const_tab) end;
 
 
 (* evaluation oracle *)
@@ -590,11 +593,11 @@
 
 val (_, raw_oracle) =
   Theory.setup_result (Thm.add_oracle (\<^binding>\<open>normalization_by_evaluation\<close>,
-    fn (nbe_program_idx_tab, ctxt, vs_ty_t, deps, ct) =>
-      mk_equals ctxt ct (normalize_term nbe_program_idx_tab ctxt (Thm.term_of ct) vs_ty_t deps)));
+    fn (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct) =>
+      mk_equals ctxt ct (normalize_term nbe_program_const_tab ctxt (Thm.term_of ct) vs_ty_t deps)));
 
-fun oracle nbe_program_idx_tab ctxt vs_ty_t deps ct =
-  raw_oracle (nbe_program_idx_tab, ctxt, vs_ty_t, deps, ct);
+fun oracle nbe_program_const_tab ctxt vs_ty_t deps ct =
+  raw_oracle (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct);
 
 fun dynamic_conv ctxt = lift_triv_classes_conv ctxt
   (fn ctxt' => Code_Thingol.dynamic_conv ctxt' (fn program =>