# HG changeset patch # User haftmann # Date 1743167617 -3600 # Node ID 2d0721461810de08e885f674c0c50c9b7f960b83 # Parent 2819f79792b968ead3c3c1dbca71d25e69e68452 tuned diff -r 2819f79792b9 -r 2d0721461810 src/Tools/nbe.ML --- a/src/Tools/nbe.ML Fri Mar 28 14:13:36 2025 +0100 +++ b/src/Tools/nbe.ML Fri Mar 28 14:13:37 2025 +0100 @@ -255,9 +255,6 @@ val univs_cookie = (get_result, put_result, name_put); -fun nbe_fun idx_of 0 (Code_Symbol.Constant "") = "nbe_value" - | nbe_fun idx_of i sym = "c_" ^ string_of_int (idx_of sym) - ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i; fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n; fun nbe_bound v = "v_" ^ v; fun nbe_bound_optional NONE = "_" @@ -267,13 +264,8 @@ (*note: these three are the "turning spots" where proper argument order is established!*) fun nbe_apps t [] = t | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)]; -fun nbe_apps_local idx_of i c ts = nbe_fun idx_of i c `$` ml_list (rev ts); -fun nbe_apps_constr ctxt idx_of c ts = - let - val c' = if Config.get ctxt trace - then string_of_int (idx_of c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)" - else string_of_int (idx_of c); - in name_const `$` ("(" ^ c' ^ ", " ^ ml_list (rev ts) ^ ")") end; +fun nbe_apps_local c ts = c `$` ml_list (rev ts); +fun nbe_apps_constr c ts = name_const `$` ("(" ^ c ^ ", " ^ ml_list (rev ts) ^ ")"); fun nbe_abss 0 f = f `$` ml_list [] | nbe_abss n f = name_abss `$$` [string_of_int n, f]; @@ -288,26 +280,37 @@ (* code generation *) -fun assemble_eqnss ctxt idx_of deps eqnss = +fun assemble_eqnss ctxt idx_of_const deps eqnss = let - fun prep_eqns (c, (vs, eqns)) = + fun prep_eqns (sym, (vs, eqns)) = let val dicts = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs; val num_args = length dicts + ((length o fst o hd) eqns); - in (c, (num_args, (dicts, eqns))) end; + in (sym, (num_args, (dicts, eqns))) end; val eqnss' = map prep_eqns eqnss; + fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value" + | fun_ident i sym = "c_" ^ string_of_int (idx_of_const sym) + ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i; + fun constr_fun_ident c = + if Config.get ctxt trace + then string_of_int (idx_of_const c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)" + else string_of_int (idx_of_const c); + + fun apply_local i sym = nbe_apps_local (fun_ident i sym); + fun apply_constr sym = nbe_apps_constr (constr_fun_ident sym); + fun assemble_constapp sym dss ts = let val ts' = (maps o map) assemble_dict dss @ ts; in case AList.lookup (op =) eqnss' sym of SOME (num_args, _) => if num_args <= length ts' then let val (ts1, ts2) = chop num_args ts' - in nbe_apps (nbe_apps_local idx_of 0 sym ts1) ts2 - end else nbe_apps (nbe_abss num_args (nbe_fun idx_of 0 sym)) ts' + in nbe_apps (apply_local 0 sym ts1) ts2 + end else nbe_apps (nbe_abss num_args (fun_ident 0 sym)) ts' | NONE => if member (op =) deps sym - then nbe_apps (nbe_fun idx_of 0 sym) ts' - else nbe_apps_constr ctxt idx_of sym ts' + then nbe_apps (fun_ident 0 sym) ts' + else apply_constr sym ts' end and assemble_classrels classrels = fold_rev (fn classrel => assemble_constapp (Class_Relation classrel) [] o single) classrels @@ -359,9 +362,9 @@ fun assemble_eqn sym dicts default_args (i, (args, rhs)) = let val match_cont = if Code_Symbol.is_value sym then NONE - else SOME (nbe_apps_local idx_of (i + 1) sym (dicts @ default_args)); + else SOME (apply_local (i + 1) sym (dicts @ default_args)); val assemble_arg = assemble_iterm - (fn sym' => fn dss => fn ts => nbe_apps_constr ctxt idx_of sym' ((maps o map) (K "_") + (fn sym' => fn dss => fn ts => apply_constr sym' ((maps o map) (K "_") dss @ ts)) NONE; val assemble_rhs = assemble_iterm assemble_constapp match_cont; val (samepairs, args') = subst_nonlin_vars args; @@ -374,19 +377,19 @@ | SOME default_rhs => [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs), ([ml_list (rev (dicts @ default_args))], default_rhs)] - in (nbe_fun idx_of i sym, eqns) end; + in (fun_ident i sym, eqns) end; fun assemble_eqns (sym, (num_args, (dicts, eqns))) = let val default_args = map nbe_default (Name.invent_global "a" (num_args - length dicts)); val eqns' = map_index (assemble_eqn sym dicts default_args) eqns - @ (if Code_Symbol.is_value sym then [] else [(nbe_fun idx_of (length eqns) sym, + @ (if Code_Symbol.is_value sym then [] else [(fun_ident (length eqns) sym, [([ml_list (rev (dicts @ default_args))], - nbe_apps_constr ctxt idx_of sym (dicts @ default_args))])]); - in (eqns', nbe_abss num_args (nbe_fun idx_of 0 sym)) end; + apply_constr sym (dicts @ default_args))])]); + in (eqns', nbe_abss num_args (fun_ident 0 sym)) end; val (fun_vars, fun_vals) = map_split assemble_eqns eqnss'; - val deps_vars = ml_list (map (nbe_fun idx_of 0) deps); + val deps_vars = ml_list (map (fun_ident 0) deps); in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end; @@ -397,19 +400,19 @@ let val (deps, deps_vals) = split_list (map_filter (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps); - val idx_of = raw_deps + val idx_of_const = raw_deps |> map (fn dep => (dep, snd (Code_Symbol.Graph.get_node nbe_program dep))) |> AList.lookup (op =) |> (fn f => the o f); - val s = assemble_eqnss ctxt idx_of deps eqnss; - val cs = map fst eqnss; + val s = assemble_eqnss ctxt idx_of_const deps eqnss; + val syms = map fst eqnss; in s |> traced ctxt (fn s => "\n--- code to be evaluated:\n" ^ s) |> pair "" |> Code_Runtime.value ctxt univs_cookie |> (fn f => f deps_vals) - |> (fn univs => cs ~~ univs) + |> (fn univs => syms ~~ univs) end; @@ -450,11 +453,11 @@ (* compilation of whole programs *) -fun ensure_const_idx name (nbe_program, (maxidx, idx_tab)) = +fun ensure_const_idx name (nbe_program, (maxidx, const_tab)) = if can (Code_Symbol.Graph.get_node nbe_program) name - then (nbe_program, (maxidx, idx_tab)) + then (nbe_program, (maxidx, const_tab)) else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program, - (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab)); + (maxidx + 1, Inttab.update_new (maxidx, name) const_tab)); fun compile_stmts ctxt stmts_deps = let @@ -472,16 +475,16 @@ fold ensure_const_idx refl_deps #> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps #> compile - #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ)))) + #-> fold (fn (sym, univ) => (Code_Symbol.Graph.map_node sym o apfst) (K (SOME univ)))) end; fun compile_program { ctxt, program } = let - fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names - then (nbe_program, (maxidx, idx_tab)) - else (nbe_program, (maxidx, idx_tab)) - |> compile_stmts ctxt (map (fn name => ((name, Code_Symbol.Graph.get_node program name), - Code_Symbol.Graph.immediate_succs program name)) names); + fun add_stmts names (nbe_program, (maxidx, const_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names + then (nbe_program, (maxidx, const_tab)) + else (nbe_program, (maxidx, const_tab)) + |> compile_stmts ctxt (map (fn sym => ((sym, Code_Symbol.Graph.get_node program sym), + Code_Symbol.Graph.immediate_succs program sym)) names); in fold_rev add_stmts (Code_Symbol.Graph.strong_conn program) end; @@ -501,18 +504,18 @@ |> (fn t => apps t (rev dict_frees)) end; -fun reconstruct_term ctxt (idx_tab : Code_Symbol.T Inttab.table) t = +fun reconstruct_term ctxt (const_tab : Code_Symbol.T Inttab.table) t = let fun take_until f [] = [] | take_until f (x :: xs) = if f x then [] else x :: take_until f xs; fun is_dict (Const (idx, _)) = - (case Inttab.lookup idx_tab idx of + (case Inttab.lookup const_tab idx of SOME (Constant _) => false | _ => true) | is_dict (DFree _) = true | is_dict _ = false; fun const_of_idx idx = - case Inttab.lookup idx_tab idx of SOME (Constant const) => const; + case Inttab.lookup const_tab idx of SOME (Constant const) => const; fun of_apps bounds (t, ts) = fold_map (of_univ bounds) ts #>> (fn ts' => list_comb (t, rev ts')) @@ -533,12 +536,12 @@ |-> (fn t' => pair (Term.Abs ("u", dummyT, t'))) in of_univ 0 t 0 |> fst end; -fun compile_and_reconstruct_term { ctxt, nbe_program, idx_tab, deps, term } = +fun compile_and_reconstruct_term { ctxt, nbe_program, const_tab, deps, term } = compile_term { ctxt = ctxt, nbe_program = nbe_program, deps = deps, term = term } - |> reconstruct_term ctxt idx_tab; + |> reconstruct_term ctxt const_tab; -fun normalize_term (nbe_program, idx_tab) raw_ctxt t_original ((vs, ty) : typscheme, t) deps = +fun normalize_term (nbe_program, const_tab) raw_ctxt t_original ((vs, ty) : typscheme, t) deps = let val ctxt = Syntax.init_pretty_global (Proof_Context.theory_of raw_ctxt); val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt); @@ -553,7 +556,7 @@ else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t'); in Code_Preproc.timed "computing NBE expression" #ctxt compile_and_reconstruct_term - { ctxt = ctxt, nbe_program = nbe_program, idx_tab = idx_tab, deps = deps, term = (vs, t) } + { ctxt = ctxt, nbe_program = nbe_program, const_tab = const_tab, deps = deps, term = (vs, t) } |> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t) |> type_infer |> traced ctxt (fn t => "Types inferred:\n" ^ string_of_term t) @@ -572,11 +575,11 @@ fun compile ignore_cache ctxt program = let - val (nbe_program, (_, idx_tab)) = + val (nbe_program, (_, const_tab)) = Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt)) (Code_Preproc.timed "compiling NBE program" #ctxt compile_program { ctxt = ctxt, program = program }); - in (nbe_program, idx_tab) end; + in (nbe_program, const_tab) end; (* evaluation oracle *) @@ -590,11 +593,11 @@ val (_, raw_oracle) = Theory.setup_result (Thm.add_oracle (\<^binding>\normalization_by_evaluation\, - fn (nbe_program_idx_tab, ctxt, vs_ty_t, deps, ct) => - mk_equals ctxt ct (normalize_term nbe_program_idx_tab ctxt (Thm.term_of ct) vs_ty_t deps))); + fn (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct) => + mk_equals ctxt ct (normalize_term nbe_program_const_tab ctxt (Thm.term_of ct) vs_ty_t deps))); -fun oracle nbe_program_idx_tab ctxt vs_ty_t deps ct = - raw_oracle (nbe_program_idx_tab, ctxt, vs_ty_t, deps, ct); +fun oracle nbe_program_const_tab ctxt vs_ty_t deps ct = + raw_oracle (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct); fun dynamic_conv ctxt = lift_triv_classes_conv ctxt (fn ctxt' => Code_Thingol.dynamic_conv ctxt' (fn program =>