restored type reconstruction for nbe: remaining type variables must remain schematic for checking
authorhaftmann
Fri, 04 Apr 2025 23:12:19 +0200
changeset 82443 3e92066d2be7
parent 82442 6d0bb3887397
child 82444 7a9164068583
restored type reconstruction for nbe: remaining type variables must remain schematic for checking
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Fri Apr 04 23:12:18 2025 +0200
+++ b/src/Tools/nbe.ML	Fri Apr 04 23:12:19 2025 +0200
@@ -546,16 +546,27 @@
     { ctxt = ctxt, nbe_program = nbe_program, deps = deps, term = term }
   |> reconstruct_term ctxt const_tab;
 
-fun normalize_term (nbe_program, const_tab) raw_ctxt t_original ((vs, ty) : typscheme, t) deps =
+fun retype_term ctxt t T =
+  let
+    val ctxt' =
+      ctxt
+      |> Variable.declare_typ T
+      |> Config.put Type_Infer.object_logic false
+      |> Config.put Type_Infer_Context.const_sorts false
+  in
+    singleton (Variable.export_terms ctxt' ctxt') (Syntax.check_term ctxt' (Type.constraint T t))
+  end;
+
+fun normalize_term (nbe_program, const_tab) raw_ctxt t_original ((vs, _) : typscheme, t) deps =
   let
     val ctxt = Syntax.init_pretty_global (Proof_Context.theory_of raw_ctxt);
-    val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
-    fun type_infer t' =
-      Syntax.check_term
-        (ctxt
-          |> Config.put Type_Infer.object_logic false
-          |> Config.put Type_Infer_Context.const_sorts false)
-        (Type.constraint (fastype_of t_original) t');
+    val string_of_term =
+      Syntax.string_of_term
+         (ctxt
+           |> Config.put show_types true
+           |> Config.put show_sorts true);
+    fun retype t' =
+      retype_term ctxt t' (fastype_of t_original);
     fun check_tvars t' =
       if null (Term.add_tvars t' []) then t'
       else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t');
@@ -563,7 +574,7 @@
     Code_Preproc.timed "computing NBE expression" #ctxt compile_and_reconstruct_term
       { ctxt = ctxt, nbe_program = nbe_program, const_tab = const_tab, deps = deps, term = (vs, t) }
     |> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t)
-    |> type_infer
+    |> retype
     |> traced ctxt (fn t => "Types inferred:\n" ^ string_of_term t)
     |> check_tvars
     |> traced ctxt (fn _ => "---\n")