# HG changeset patch # User haftmann # Date 1743955973 -7200 # Node ID ba1f9fb23b8d76ab4f2722b182271fc5484c26f9 # Parent f9e6cbc6bf22947f5b8bf10c0ff578088a3c9a02 clarified variable names diff -r f9e6cbc6bf22 -r ba1f9fb23b8d src/Tools/nbe.ML --- a/src/Tools/nbe.ML Sun Apr 06 18:12:52 2025 +0200 +++ b/src/Tools/nbe.ML Sun Apr 06 18:12:53 2025 +0200 @@ -184,28 +184,27 @@ | Abs of (int * (Univ list -> Univ)) * Univ list (*abstractions as closures*) - (* constructor functions *) fun abss n f = Abs ((n, f), []); -fun apps (Abs ((n, f), xs)) ys = let val k = n - length ys in +fun apps (Abs ((n, f), us)) ws = let val k = n - length ws in case int_ord (k, 0) - of EQUAL => f (ys @ xs) - | LESS => let val (zs, ws) = chop (~ k) ys in apps (f (ws @ xs)) zs end - | GREATER => Abs ((k, f), ys @ xs) (*note: reverse convention also for apps!*) + of EQUAL => f (ws @ us) + | LESS => let val (ws1, ws2) = chop (~ k) ws in apps (f (ws2 @ us)) ws1 end + | GREATER => Abs ((k, f), ws @ us) (*note: reverse convention also for apps!*) end | apps (Const (name, xs)) ys = Const (name, ys @ xs) | apps (BVar (n, xs)) ys = BVar (n, ys @ xs); -fun same_type (Type (tyco1, types1), Type (tyco2, types2)) = - (tyco1 = tyco2) andalso eq_list same_type (types1, types2) +fun same_type (Type (tyco1, ty1), Type (tyco2, tys2)) = + (tyco1 = tyco2) andalso eq_list same_type (ty1, tys2) | same_type (TParam v1, TParam v2) = (v1 = v2) | same_type _ = false; -fun same (Const ((k1, ts1), xs1), Const ((k2, ts2), xs2)) = - (k1 = k2) andalso eq_list same_type (ts1, ts2) andalso eq_list same (xs1, xs2) +fun same (Const ((k1, typargs1), us1), Const ((k2, typargs2), us2)) = + (k1 = k2) andalso eq_list same_type (typargs1, typargs2) andalso eq_list same (us1, us2) | same (DFree (n1, i1), DFree (n2, i2)) = (n1 = n2) andalso (i1 = i2) - | same (BVar (i1, xs1), BVar (i2, xs2)) = (i1 = i2) andalso eq_list same (xs1, xs2) + | same (BVar (i1, us1), BVar (i2, us2)) = (i1 = i2) andalso eq_list same (us1, us2) | same _ = false; @@ -271,21 +270,27 @@ val univs_cookie = (get_result, put_result, name_put); +(* + Convention: parameters representing ("assembled") string representations of logical entities + are prefixed with an "a_" -- unless they are an unqualified name ready to become part of + an ML identifier. +*) + fun nbe_tparam v = "t_" ^ v; fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n; fun nbe_bound v = "v_" ^ v; fun nbe_bound_optional NONE = "_" | nbe_bound_optional (SOME v) = nbe_bound v; fun nbe_default v = "w_" ^ v; -fun nbe_type n ts = name_type `$` (quote n `*` ml_list ts); -fun nbe_fun c tys = c `$` ml_list tys; +fun nbe_type a_sym a_tys = name_type `$` (quote a_sym `*` ml_list a_tys); +fun nbe_fun a_sym a_typargs = a_sym `$` ml_list a_typargs; (*note: these three are the "turning spots" where proper argument order is established!*) -fun nbe_apps t [] = t - | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)]; -fun nbe_apps_fun c tys ts = nbe_fun c tys `$` ml_list (rev ts); -fun nbe_apps_constr c tys ts = name_const `$` ((c `*` ml_list tys) `*` ml_list (rev ts)); -fun nbe_apps_constmatch c ts = name_const `$` ((c `*` "_") `*` ml_list (rev ts)); +fun nbe_apps a_u [] = a_u + | nbe_apps a_u a_us = name_apps `$$` [a_u, ml_list (rev a_us)]; +fun nbe_apps_fun a_sym a_typargs a_us = nbe_fun a_sym a_typargs `$` ml_list (rev a_us); +fun nbe_apps_constr a_sym a_typargs a_us = name_const `$` ((a_sym `*` ml_list a_typargs) `*` ml_list (rev a_us)); +fun nbe_apps_constmatch a_sym a_us = name_const `$` ((a_sym `*` "_") `*` ml_list (rev a_us)); fun nbe_abss 0 f = f `$` ml_list [] | nbe_abss n f = name_abss `$$` [string_of_int n, f]; @@ -324,54 +329,54 @@ fun preprocess_eqns (sym, (vs, eqns)) = let - val s_tparams = map (fn (v, _) => nbe_tparam v) vs; - val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs; - val num_args = length dict_params + ((length o fst o hd) eqns); - val default_params = map nbe_default (Name.invent_global "a" (num_args - length dict_params)); - in (sym, (num_args, (s_tparams, dict_params, (map o apfst) subst_nonlin_vars eqns, default_params))) end; + val a_tparams = map (fn (v, _) => nbe_tparam v) vs; + val a_dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs; + val num_args = length a_dict_params + ((length o fst o hd) eqns); + val a_default_params = map nbe_default (Name.invent_global "a" (num_args - length a_dict_params)); + in (sym, (num_args, (a_tparams, a_dict_params, (map o apfst) subst_nonlin_vars eqns, a_default_params))) end; fun assemble_type (tyco `%% tys) = nbe_type tyco (map assemble_type tys) | assemble_type (ITyVar v) = nbe_tparam v -fun assemble_preprocessed_eqnss ctxt idx_of_const deps eqnss = +fun assemble_preprocessed_eqnss ctxt idx_of_sym deps eqnss = let fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value" - | fun_ident i sym = "c_" ^ string_of_int (idx_of_const sym) + | fun_ident i sym = "c_" ^ string_of_int (idx_of_sym sym) ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i; - fun constr_ident c = + fun constr_ident sym = if Config.get ctxt trace - then string_of_int (idx_of_const c) ^ " (*" ^ Code_Symbol.default_base c ^ "*)" - else string_of_int (idx_of_const c); + then string_of_int (idx_of_sym sym) ^ " (*" ^ Code_Symbol.default_base sym ^ "*)" + else string_of_int (idx_of_sym sym); fun assemble_fun i sym = nbe_fun (fun_ident i sym); fun assemble_app_fun i sym = nbe_apps_fun (fun_ident i sym); fun assemble_app_constr sym = nbe_apps_constr (constr_ident sym); fun assemble_app_constmatch sym = nbe_apps_constmatch (constr_ident sym); - fun assemble_constapp sym tys dictss ts = + fun assemble_constapp sym typargs dictss a_ts = let - val s_tys = map (assemble_type) tys; - val ts' = (maps o map) assemble_dict (map2 (fn ty => map (fn dict => (ty, dict))) tys dictss) @ ts; + val a_typargs = map (assemble_type) typargs; + val a_ts' = (maps o map) assemble_dict (map2 (fn ty => map (fn dict => (ty, dict))) typargs dictss) @ a_ts; in case AList.lookup (op =) eqnss sym - of SOME (num_args, _) => if num_args <= length ts' - then let val (ts1, ts2) = chop num_args ts' - in nbe_apps (assemble_app_fun 0 sym s_tys ts1) ts2 - end else nbe_apps (nbe_abss num_args (assemble_fun 0 sym s_tys)) ts' + of SOME (num_args, _) => if num_args <= length a_ts' + then let val (a_ts1, a_ts2) = chop num_args a_ts' + in nbe_apps (assemble_app_fun 0 sym a_typargs a_ts1) a_ts2 + end else nbe_apps (nbe_abss num_args (assemble_fun 0 sym a_typargs)) a_ts' | NONE => if member (op =) deps sym - then nbe_apps (assemble_fun 0 sym s_tys) ts' - else assemble_app_constr sym s_tys ts' + then nbe_apps (assemble_fun 0 sym a_typargs) a_ts' + else assemble_app_constr sym a_typargs a_ts' end and assemble_classrels classrels = fold_rev (fn classrel => assemble_constapp (Class_Relation classrel) [] [] o single) classrels - and assemble_dict (ty, Dict (classrels, x)) = - assemble_classrels classrels (assemble_plain_dict ty x) + and assemble_dict (ty, Dict (classrels, dict)) = + assemble_classrels classrels (assemble_plain_dict ty dict) and assemble_plain_dict (_ `%% tys) (Dict_Const (inst, dictss)) = assemble_constapp (Class_Instance inst) tys (map snd dictss) [] | assemble_plain_dict _ (Dict_Var { var, index, ... }) = nbe_dict var index - fun assemble_constmatch sym _ dictss ts = - assemble_app_constmatch sym ((maps o map) (K "_") dictss @ ts); + fun assemble_constmatch sym _ dictss a_ts = + assemble_app_constmatch sym ((maps o map) (K "_") dictss @ a_ts); fun assemble_iterm constapp = let @@ -379,50 +384,51 @@ let val (t', ts) = Code_Thingol.unfold_app t in of_iapp match_continuation t' (fold_rev (cons o of_iterm NONE) ts []) end - and of_iapp match_continuation (IConst { sym, typargs = tys, dictss, ... }) ts = constapp sym tys dictss ts - | of_iapp match_continuation (IVar v) ts = nbe_apps (nbe_bound_optional v) ts - | of_iapp match_continuation ((v, _) `|=> (t, _)) ts = - nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t))) ts - | of_iapp match_continuation (ICase { term = t, clauses = clauses, primitive = t0, ... }) ts = + and of_iapp match_continuation (IConst { sym, typargs, dictss, ... }) = constapp sym typargs dictss + | of_iapp match_continuation (IVar v) = nbe_apps (nbe_bound_optional v) + | of_iapp match_continuation ((v, _) `|=> (t, _)) = + nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t))) + | of_iapp match_continuation (ICase { term = t, clauses = clauses, primitive = t0, ... }) = nbe_apps (ml_cases (of_iterm NONE t) (map (fn (p, t) => (assemble_iterm assemble_constmatch NONE p, of_iterm match_continuation t)) clauses - @ [("_", case match_continuation of SOME s => s | NONE => of_iterm NONE t0)])) ts + @ [("_", case match_continuation of SOME s => s | NONE => of_iterm NONE t0)])) in of_iterm end; val assemble_args = map (assemble_iterm assemble_constmatch NONE); val assemble_rhs = assemble_iterm assemble_constapp; - fun assemble_eqn sym s_tparams dict_params default_params (i, ((samepairs, args), rhs)) = + fun assemble_eqn sym a_tparams a_dict_params a_default_params (i, ((samepairs, args), rhs)) = let - val default_rhs = assemble_app_fun (i + 1) sym s_tparams (dict_params @ default_params); + val default_rhs = assemble_app_fun (i + 1) sym a_tparams (a_dict_params @ a_default_params); val s_args = assemble_args args; val s_rhs = if null samepairs then assemble_rhs (SOME default_rhs) rhs else ml_if (ml_and (map nbe_same samepairs)) (assemble_rhs (SOME default_rhs) rhs) default_rhs; - val eqns = [([ml_list s_tparams, ml_list (rev (dict_params @ map2 ml_as default_params s_args))], s_rhs), - ([ml_list s_tparams, ml_list (rev (dict_params @ default_params))], default_rhs)] + val eqns = [([ml_list a_tparams, ml_list (rev (a_dict_params @ map2 ml_as a_default_params s_args))], s_rhs), + ([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))], default_rhs)] in (fun_ident i sym, eqns) end; - fun assemble_default_eqn sym s_tparams dict_params default_params i = + fun assemble_default_eqn sym a_tparams a_dict_params a_default_params i = (fun_ident i sym, - [([ml_list s_tparams, ml_list (rev (dict_params @ default_params))], assemble_app_constr sym s_tparams (dict_params @ default_params))]) + [([ml_list a_tparams, ml_list (rev (a_dict_params @ a_default_params))], + assemble_app_constr sym a_tparams (a_dict_params @ a_default_params))]) - fun assemble_value_eqn sym s_tparams dict_params (([], args), rhs) = + fun assemble_value_eqn sym a_tparams a_dict_params (([], args), rhs) = (fun_ident 0 sym, - [([ml_list s_tparams, ml_list (rev (dict_params @ assemble_args args))], assemble_rhs NONE rhs)]); + [([ml_list a_tparams, ml_list (rev (a_dict_params @ assemble_args args))], assemble_rhs NONE rhs)]); - fun assemble_eqns (sym, (num_args, (s_tparams, dict_params, eqns, default_params))) = - (if Code_Symbol.is_value sym then [assemble_value_eqn sym s_tparams dict_params (the_single eqns)] - else map_index (assemble_eqn sym s_tparams dict_params default_params) eqns - @ [assemble_default_eqn sym s_tparams dict_params default_params (length eqns)], - ml_abs (ml_list s_tparams) (nbe_abss num_args (assemble_fun 0 sym s_tparams))); + fun assemble_eqns (sym, (num_args, (a_tparams, a_dict_params, eqns, default_params))) = + (if Code_Symbol.is_value sym then [assemble_value_eqn sym a_tparams a_dict_params (the_single eqns)] + else map_index (assemble_eqn sym a_tparams a_dict_params default_params) eqns + @ [assemble_default_eqn sym a_tparams a_dict_params default_params (length eqns)], + ml_abs (ml_list a_tparams) (nbe_abss num_args (assemble_fun 0 sym a_tparams))); val (fun_vars, fun_vals) = map_split assemble_eqns eqnss; val deps_vars = ml_list (map (fun_ident 0) deps); in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end; -fun assemble_eqnss ctxt idx_of_const deps eqnss = - assemble_preprocessed_eqnss ctxt idx_of_const deps (map preprocess_eqns eqnss); +fun assemble_eqnss ctxt idx_of_sym deps eqnss = + assemble_preprocessed_eqnss ctxt idx_of_sym deps (map preprocess_eqns eqnss); (* compilation of equations *) @@ -432,26 +438,26 @@ let val (deps, deps_vals) = split_list (map_filter (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps); - val idx_of_const = raw_deps + val idx_of_sym = raw_deps |> map (fn dep => (dep, snd (Code_Symbol.Graph.get_node nbe_program dep))) |> AList.lookup (op =) |> (fn f => the o f); - val s = assemble_eqnss ctxt idx_of_const deps eqnss; + val s = assemble_eqnss ctxt idx_of_sym deps eqnss; val syms = map fst eqnss; in s |> traced ctxt (fn s => "\n--- code to be evaluated:\n" ^ s) |> pair "" |> Code_Runtime.value ctxt univs_cookie - |> (fn f => f deps_vals) - |> (fn poly_univs => syms ~~ poly_univs) + |> (fn dependent_fs => dependent_fs deps_vals) + |> (fn fs => syms ~~ fs) end; (* extraction of equations from statements *) -fun dummy_const sym tys dictss = - IConst { sym = sym, typargs = tys, dictss = dictss, +fun dummy_const sym typargs dictss = + IConst { sym = sym, typargs = typargs, dictss = dictss, dom = [], annotation = NONE, range = ITyVar "" }; fun eqns_of_stmt (_, Code_Thingol.NoStmt) = @@ -486,11 +492,11 @@ (* compilation of whole programs *) -fun ensure_const_idx name (nbe_program, (maxidx, const_tab)) = - if can (Code_Symbol.Graph.get_node nbe_program) name - then (nbe_program, (maxidx, const_tab)) - else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program, - (maxidx + 1, Inttab.update_new (maxidx, name) const_tab)); +fun ensure_sym_idx sym (nbe_program, (max_idx, sym_tab)) = + if can (Code_Symbol.Graph.get_node nbe_program) sym + then (nbe_program, (max_idx, sym_tab)) + else (Code_Symbol.Graph.new_node (sym, (NONE, max_idx)) nbe_program, + (max_idx + 1, Inttab.update_new (max_idx, sym) sym_tab)); fun compile_stmts ctxt stmts_deps = let @@ -505,7 +511,7 @@ |> compile_eqnss ctxt nbe_program refl_deps |> rpair nbe_program; in - fold ensure_const_idx refl_deps + fold ensure_sym_idx refl_deps #> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps #> compile #-> fold (fn (sym, univ) => (Code_Symbol.Graph.map_node sym o apfst) (K (SOME univ)))) @@ -513,10 +519,10 @@ fun compile_program { ctxt, program } = let - fun add_stmts names (nbe_program, (maxidx, const_tab)) = + fun add_stmts names (nbe_program, (max_idx, sym_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names - then (nbe_program, (maxidx, const_tab)) - else (nbe_program, (maxidx, const_tab)) + then (nbe_program, (max_idx, sym_tab)) + else (nbe_program, (max_idx, sym_tab)) |> compile_stmts ctxt (map (fn sym => ((sym, Code_Symbol.Graph.get_node program sym), Code_Symbol.Graph.immediate_succs program sym)) names); in @@ -542,37 +548,37 @@ |> (fn f => apps (f tparams) (rev dict_frees)) end; -fun reconstruct_term ctxt const_tab tfrees t = +fun reconstruct_term ctxt sym_tab tfrees = let fun take_until f [] = [] | take_until f (x :: xs) = if f x then [] else x :: take_until f xs; fun is_dict (Const ((idx, _), _)) = - (case Inttab.lookup const_tab idx of + (case Inttab.lookup sym_tab idx of SOME (Constant _) => false | _ => true) | is_dict (DFree _) = true | is_dict _ = false; fun const_of_idx idx = - case Inttab.lookup const_tab idx of SOME (Constant const) => const; + case Inttab.lookup sym_tab idx of SOME (Constant const) => const; fun reconstruct_type (Type (tyco, tys)) = Term.Type (tyco, map reconstruct_type tys) | reconstruct_type (TParam v) = TFree (v, the (AList.lookup (op =) tfrees v)); fun of_apps bounds (t, ts) = list_comb (t, rev (map (of_univ bounds) ts)) - and of_univ bounds (Const ((idx, tys), ts)) = + and of_univ bounds (Const ((idx, tparams), us)) = let val const = const_of_idx idx; - val ts' = take_until is_dict ts; - val T = Consts.instance (Proof_Context.consts_of ctxt) (const, map reconstruct_type tys); - in of_apps bounds (Term.Const (const, T), ts') end - | of_univ bounds (BVar (n, ts)) = - of_apps bounds (Bound (bounds - n - 1), ts) - | of_univ bounds (t as Abs _) = - Term.Abs ("u", dummyT, of_univ (bounds + 1) (apps t [BVar (bounds, [])])) - in of_univ 0 t end; + val us' = take_until is_dict us; + val T = Consts.instance (Proof_Context.consts_of ctxt) (const, map reconstruct_type tparams); + in of_apps bounds (Term.Const (const, T), us') end + | of_univ bounds (BVar (i, us)) = + of_apps bounds (Bound (bounds - i - 1), us) + | of_univ bounds (u as Abs _) = + Term.Abs ("u", dummyT, of_univ (bounds + 1) (apps u [BVar (bounds, [])])) + in of_univ 0 end; -fun compile_and_reconstruct_term { ctxt, nbe_program, const_tab, deps, tfrees, vs_ty_t } = +fun compile_and_reconstruct_term { ctxt, nbe_program, sym_tab, deps, tfrees, vs_ty_t } = compile_term { ctxt = ctxt, nbe_program = nbe_program, deps = deps, tfrees = tfrees, vs_ty_t = vs_ty_t } - |> reconstruct_term ctxt const_tab tfrees; + |> reconstruct_term ctxt sym_tab tfrees; fun retype_term ctxt t T = let @@ -585,7 +591,7 @@ singleton (Variable.export_terms ctxt' ctxt') (Syntax.check_term ctxt' (Type.constraint T t)) end; -fun normalize_term (nbe_program, const_tab) raw_ctxt t_original vs_ty_t deps = +fun normalize_term (nbe_program, sym_tab) raw_ctxt t_original vs_ty_t deps = let val T = fastype_of t_original; val tfrees = Term.add_tfrees t_original []; @@ -601,7 +607,7 @@ else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t'); in Code_Preproc.timed "computing NBE expression" #ctxt compile_and_reconstruct_term - { ctxt = ctxt, nbe_program = nbe_program, const_tab = const_tab, deps = deps, + { ctxt = ctxt, nbe_program = nbe_program, sym_tab = sym_tab, deps = deps, tfrees = tfrees, vs_ty_t = vs_ty_t } |> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t) |> retype @@ -621,11 +627,11 @@ fun compile ignore_cache ctxt program = let - val (nbe_program, (_, const_tab)) = + val (nbe_program, (_, sym_tab)) = Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt)) (Code_Preproc.timed "compiling NBE program" #ctxt compile_program { ctxt = ctxt, program = program }); - in (nbe_program, const_tab) end; + in (nbe_program, sym_tab) end; (* evaluation oracle *) @@ -639,11 +645,11 @@ val (_, raw_oracle) = Theory.setup_result (Thm.add_oracle (\<^binding>\normalization_by_evaluation\, - fn (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct) => - mk_equals ctxt ct (normalize_term nbe_program_const_tab ctxt (Thm.term_of ct) vs_ty_t deps))); + fn (nbe_program_sym_tab, ctxt, vs_ty_t, deps, ct) => + mk_equals ctxt ct (normalize_term nbe_program_sym_tab ctxt (Thm.term_of ct) vs_ty_t deps))); -fun oracle nbe_program_const_tab ctxt vs_ty_t deps ct = - raw_oracle (nbe_program_const_tab, ctxt, vs_ty_t, deps, ct); +fun oracle nbe_program_sym_tab ctxt vs_ty_t deps ct = + raw_oracle (nbe_program_sym_tab, ctxt, vs_ty_t, deps, ct); fun dynamic_conv ctxt = lift_triv_classes_conv ctxt (fn ctxt' => Code_Thingol.dynamic_conv ctxt' (fn program =>