src/Tools/nbe.ML
changeset 82375 1972ae7da0d2
parent 82374 2d0721461810
child 82442 6d0bb3887397
--- a/src/Tools/nbe.ML	Fri Mar 28 14:13:37 2025 +0100
+++ b/src/Tools/nbe.ML	Fri Mar 28 14:13:38 2025 +0100
@@ -280,15 +280,37 @@
 
 (* code generation *)
 
-fun assemble_eqnss ctxt idx_of_const deps eqnss =
+fun subst_nonlin_vars args =
   let
-    fun prep_eqns (sym, (vs, eqns)) =
-      let
-        val dicts = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
-        val num_args = length dicts + ((length o fst o hd) eqns);
-      in (sym, (num_args, (dicts, eqns))) end;
-    val eqnss' = map prep_eqns eqnss;
+    val vs = (fold o Code_Thingol.fold_varnames)
+      (fn v => AList.map_default (op =) (v, 0) (Integer.add 1)) args [];
+    val names = Name.make_context (map fst vs);
+    val (vs_renames, _) = fold_map (fn (v, k) => if k > 1
+      then Name.invent' v (k - 1) #>> (fn vs => (v, vs))
+      else pair (v, [])) vs names;
+    val samepairs = maps (fn (v, vs) => map (pair v) vs) vs_renames;
+    fun subst_vars (t as IConst _) samepairs = (t, samepairs)
+      | subst_vars (t as IVar NONE) samepairs = (t, samepairs)
+      | subst_vars (t as IVar (SOME v)) samepairs = (case AList.lookup (op =) samepairs v
+         of SOME v' => (IVar (SOME v'), AList.delete (op =) v samepairs)
+          | NONE => (t, samepairs))
+      | subst_vars (t1 `$ t2) samepairs = samepairs
+          |> subst_vars t1
+          ||>> subst_vars t2
+          |>> (op `$)
+      | subst_vars (ICase { primitive = t, ... }) samepairs = subst_vars t samepairs;
+    val (args', _) = fold_map subst_vars args samepairs;
+  in (samepairs, args') end;
 
+fun preprocess_eqns (sym, (vs, eqns)) =
+  let
+    val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
+    val num_args = length dict_params + ((length o fst o hd) eqns);
+    val default_params = map nbe_default (Name.invent_global "a" (num_args - length dict_params));
+  in (sym, (num_args, (dict_params, (map o apfst) subst_nonlin_vars eqns, default_params))) end;
+
+fun assemble_preprocessed_eqnss ctxt idx_of_const deps eqnss =
+  let
     fun fun_ident 0 (Code_Symbol.Constant "") = "nbe_value"
       | fun_ident i sym = "c_" ^ string_of_int (idx_of_const sym)
           ^ "_" ^ Code_Symbol.default_base sym ^ "_" ^ string_of_int i;
@@ -300,10 +322,10 @@
     fun apply_local i sym = nbe_apps_local (fun_ident i sym);
     fun apply_constr sym = nbe_apps_constr (constr_fun_ident sym);
 
-    fun assemble_constapp sym dss ts = 
+    fun assemble_constapp sym dicts ts = 
       let
-        val ts' = (maps o map) assemble_dict dss @ ts;
-      in case AList.lookup (op =) eqnss' sym
+        val ts' = (maps o map) assemble_dict dicts @ ts;
+      in case AList.lookup (op =) eqnss sym
        of SOME (num_args, _) => if num_args <= length ts'
             then let val (ts1, ts2) = chop num_args ts'
             in nbe_apps (apply_local 0 sym ts1) ts2
@@ -316,82 +338,64 @@
       fold_rev (fn classrel => assemble_constapp (Class_Relation classrel) [] o single) classrels
     and assemble_dict (Dict (classrels, x)) =
           assemble_classrels classrels (assemble_plain_dict x)
-    and assemble_plain_dict (Dict_Const (inst, dss)) =
-          assemble_constapp (Class_Instance inst) (map snd dss) []
+    and assemble_plain_dict (Dict_Const (inst, dicts)) =
+          assemble_constapp (Class_Instance inst) (map snd dicts) []
       | assemble_plain_dict (Dict_Var { var, index, ... }) =
           nbe_dict var index
 
+    fun assemble_constmatch sym dicts ts =
+      apply_constr sym ((maps o map) (K "_") dicts @ ts);
+
     fun assemble_iterm constapp =
       let
-        fun of_iterm match_cont t =
+        fun of_iterm match_continuation t =
           let
             val (t', ts) = Code_Thingol.unfold_app t
-          in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end
-        and of_iapp match_cont (IConst { sym, dicts = dss, ... }) ts = constapp sym dss ts
-          | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound_optional v) ts
-          | of_iapp match_cont ((v, _) `|=> (t, _)) ts =
+          in of_iapp match_continuation t' (fold_rev (cons o of_iterm NONE) ts []) end
+        and of_iapp match_continuation (IConst { sym, dicts, ... }) ts = constapp sym dicts ts
+          | of_iapp match_continuation (IVar v) ts = nbe_apps (nbe_bound_optional v) ts
+          | of_iapp match_continuation ((v, _) `|=> (t, _)) ts =
               nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound_optional v]) (of_iterm NONE t))) ts
-          | of_iapp match_cont (ICase { term = t, clauses = clauses, primitive = t0, ... }) ts =
+          | of_iapp match_continuation (ICase { term = t, clauses = clauses, primitive = t0, ... }) ts =
               nbe_apps (ml_cases (of_iterm NONE t)
-                (map (fn (p, t) => (of_iterm NONE p, of_iterm match_cont t)) clauses
-                  @ [("_", case match_cont of SOME s => s | NONE => of_iterm NONE t0)])) ts
+                (map (fn (p, t) => (of_iterm NONE p, of_iterm match_continuation t)) clauses
+                  @ [("_", case match_continuation of SOME s => s | NONE => of_iterm NONE t0)])) ts
       in of_iterm end;
 
-    fun subst_nonlin_vars args =
+    val assemble_args = map (assemble_iterm assemble_constmatch NONE);
+    val assemble_rhs = assemble_iterm assemble_constapp;
+
+    fun assemble_eqn sym dict_params default_params (i, ((samepairs, args), rhs)) =
       let
-        val vs = (fold o Code_Thingol.fold_varnames)
-          (fn v => AList.map_default (op =) (v, 0) (Integer.add 1)) args [];
-        val names = Name.make_context (map fst vs);
-        val (vs_renames, _) = fold_map (fn (v, k) => if k > 1
-          then Name.invent' v (k - 1) #>> (fn vs => (v, vs))
-          else pair (v, [])) vs names;
-        val samepairs = maps (fn (v, vs) => map (pair v) vs) vs_renames;
-        fun subst_vars (t as IConst _) samepairs = (t, samepairs)
-          | subst_vars (t as IVar NONE) samepairs = (t, samepairs)
-          | subst_vars (t as IVar (SOME v)) samepairs = (case AList.lookup (op =) samepairs v
-             of SOME v' => (IVar (SOME v'), AList.delete (op =) v samepairs)
-              | NONE => (t, samepairs))
-          | subst_vars (t1 `$ t2) samepairs = samepairs
-              |> subst_vars t1
-              ||>> subst_vars t2
-              |>> (op `$)
-          | subst_vars (ICase { primitive = t, ... }) samepairs = subst_vars t samepairs;
-        val (args', _) = fold_map subst_vars args samepairs;
-      in (samepairs, args') end;
-
-    fun assemble_eqn sym dicts default_args (i, (args, rhs)) =
-      let
-        val match_cont = if Code_Symbol.is_value sym then NONE
-          else SOME (apply_local (i + 1) sym (dicts @ default_args));
-        val assemble_arg = assemble_iterm
-          (fn sym' => fn dss => fn ts => apply_constr sym' ((maps o map) (K "_")
-            dss @ ts)) NONE;
-        val assemble_rhs = assemble_iterm assemble_constapp match_cont;
-        val (samepairs, args') = subst_nonlin_vars args;
-        val s_args = map assemble_arg args';
-        val s_rhs = if null samepairs then assemble_rhs rhs
+        val default_rhs = apply_local (i + 1) sym (dict_params @ default_params);
+        val s_args = assemble_args args;
+        val s_rhs = if null samepairs then assemble_rhs (SOME default_rhs) rhs
           else ml_if (ml_and (map nbe_same samepairs))
-            (assemble_rhs rhs) (the match_cont);
-        val eqns = case match_cont
-         of NONE => [([ml_list (rev (dicts @ s_args))], s_rhs)]
-          | SOME default_rhs =>
-              [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
-                ([ml_list (rev (dicts @ default_args))], default_rhs)]
+            (assemble_rhs (SOME default_rhs) rhs) default_rhs;
+        val eqns = [([ml_list (rev (dict_params @ map2 ml_as default_params s_args))], s_rhs),
+          ([ml_list (rev (dict_params @ default_params))], default_rhs)]
       in (fun_ident i sym, eqns) end;
 
-    fun assemble_eqns (sym, (num_args, (dicts, eqns))) =
-      let
-        val default_args = map nbe_default (Name.invent_global "a" (num_args - length dicts));
-        val eqns' = map_index (assemble_eqn sym dicts default_args) eqns
-          @ (if Code_Symbol.is_value sym then [] else [(fun_ident (length eqns) sym,
-            [([ml_list (rev (dicts @ default_args))],
-              apply_constr sym (dicts @ default_args))])]);
-      in (eqns', nbe_abss num_args (fun_ident 0 sym)) end;
+    fun assemble_default_eqn sym dict_params default_params i =
+      (fun_ident i sym,
+        [([ml_list (rev (dict_params @ default_params))], apply_constr sym (dict_params @ default_params))]);
+
+    fun assemble_value_equation sym dict_params (([], args), rhs) =
+      (fun_ident 0 sym, [([ml_list (rev (dict_params @ assemble_args args))], assemble_rhs NONE rhs)]);
 
-    val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
+    fun assemble_eqns (sym, (num_args, (dict_params, eqns, default_params))) =
+      (if Code_Symbol.is_value sym then [assemble_value_equation sym dict_params (the_single eqns)]
+      else map_index (assemble_eqn sym dict_params default_params) eqns
+        @ [assemble_default_eqn sym dict_params default_params (length eqns)],
+      nbe_abss num_args (fun_ident 0 sym));
+
+    val (fun_vars, fun_vals) = map_split assemble_eqns eqnss;
     val deps_vars = ml_list (map (fun_ident 0) deps);
   in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
 
+fun assemble_eqnss ctxt idx_of_const deps eqnss =
+  assemble_preprocessed_eqnss ctxt idx_of_const deps (map preprocess_eqns eqnss);
+
 
 (* compilation of equations *)
 
@@ -418,8 +422,8 @@
 
 (* extraction of equations from statements *)
 
-fun dummy_const sym dss =
-  IConst { sym = sym, typargs = [], dicts = dss,
+fun dummy_const sym dicts =
+  IConst { sym = sym, typargs = [], dicts = dicts,
     dom = [], annotation = NONE, range = ITyVar "" };
 
 fun eqns_of_stmt (_, Code_Thingol.NoStmt) =
@@ -447,7 +451,7 @@
       []
   | eqns_of_stmt (sym_inst, Code_Thingol.Classinst { class, tyco, vs, superinsts, inst_params, ... }) =
       [(sym_inst, (vs, [([], dummy_const (Type_Class class) [] `$$
-        map (fn (class, dss) => dummy_const (Class_Instance (tyco, class)) (map snd dss)) superinsts
+        map (fn (class, dicts) => dummy_const (Class_Instance (tyco, class)) (map snd dicts)) superinsts
         @ map (IConst o fst o snd o fst) inst_params)]))];
 
 
@@ -480,7 +484,8 @@
 
 fun compile_program { ctxt, program } =
   let
-    fun add_stmts names (nbe_program, (maxidx, const_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
+    fun add_stmts names (nbe_program, (maxidx, const_tab)) =
+      if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
       then (nbe_program, (maxidx, const_tab))
       else (nbe_program, (maxidx, const_tab))
         |> compile_stmts ctxt (map (fn sym => ((sym, Code_Symbol.Graph.get_node program sym),
@@ -521,8 +526,8 @@
       #>> (fn ts' => list_comb (t, rev ts'))
     and of_univ bounds (Const (idx, ts)) typidx =
           let
+            val const = const_of_idx idx;
             val ts' = take_until is_dict ts;
-            val const = const_of_idx idx;
             val T = map_type_tvar (fn ((v, i), _) =>
               Type_Infer.param typidx (v ^ string_of_int i, []))
                 (Sign.the_const_type (Proof_Context.theory_of ctxt) const);