src/HOL/Import/import_rule.ML
changeset 81854 2a5cbd329241
parent 81853 f06281e21df9
child 81855 a001d14f150c
--- a/src/HOL/Import/import_rule.ML	Fri Jan 17 13:44:45 2025 +0100
+++ b/src/HOL/Import/import_rule.ML	Fri Jan 17 14:31:48 2025 +0100
@@ -31,10 +31,6 @@
 structure Import_Rule: IMPORT_RULE =
 struct
 
-type state = (ctyp Inttab.table * int) * (cterm Inttab.table * int) * (thm Inttab.table * int)
-
-val init_state: state = ((Inttab.empty, 0), (Inttab.empty, 0), (Inttab.empty, 0))
-
 fun implies_elim_all th = implies_elim_list th (map Thm.assume (cprems_of th))
 
 fun meta_mp th1 th2 =
@@ -308,23 +304,39 @@
       | NONE => Sign.full_bname thy (make_name c))
   in Const (d, ty) end
 
-fun get (map, no) s =
-  case Int.fromString s of
+
+datatype state =
+  State of theory * (ctyp Inttab.table * int) * (cterm Inttab.table * int) * (thm Inttab.table * int)
+
+fun init_state thy = State (thy, (Inttab.empty, 0), (Inttab.empty, 0), (Inttab.empty, 0))
+
+fun get (tab, no) s =
+  (case Int.fromString s of
     NONE => error "Import_Rule.get: not a number"
-  | SOME i => (case Inttab.lookup map (Int.abs i) of
-      NONE => error "Import_Rule.get: lookup failed"
-    | SOME res => (res, (if i < 0 then Inttab.delete (Int.abs i) map else map, no)))
+  | SOME i =>
+      (case Inttab.lookup tab (Int.abs i) of
+        NONE => error "Import_Rule.get: lookup failed"
+      | SOME res => (res, (if i < 0 then Inttab.delete (Int.abs i) tab else tab, no))))
 
-fun typ i (thy, (tyi, tmi, thi)) = let val (i, tyi) = (get tyi i) in (i, (thy, (tyi, tmi, thi))) end
-fun term i (thy, (tyi, tmi, thi)) = let val (i, tmi) = (get tmi i) in (i, (thy, (tyi, tmi, thi))) end
-fun thm i (thy, (tyi, tmi, thi)) = let val (i, thi) = (get thi i) in (i, (thy, (tyi, tmi, thi))) end
-fun set (map, no) v = (Inttab.update_new (no + 1, v) map, no + 1)
-fun set_typ v (thy, (tyi, tmi, thi)) = (thy, (set tyi v, tmi, thi))
-fun set_term v (thy, (tyi, tmi, thi)) = (thy, (tyi, set tmi v, thi))
-fun set_thm v (thy, (tyi, tmi, thi)) = (thy, (tyi, tmi, set thi v))
+fun get_theory (State (thy, _, _, _)) = thy;
+fun map_theory f (State (thy, a, b, c)) = State (f thy, a, b, c);
+fun map_theory_result f (State (thy, a, b, c)) =
+  let val (res, thy') = f thy in (res, State (thy', a, b, c)) end;
+
+fun ctyp_of (State (thy, _, _, _)) = Thm.global_ctyp_of thy;
+fun cterm_of (State (thy, _, _, _)) = Thm.global_cterm_of thy;
 
-fun last_thm (_, _, (map, no)) =
-  case Inttab.lookup map no of
+fun typ i (State (thy, a, b, c)) = let val (i, a') = get a i in (i, State (thy, a', b, c)) end
+fun term i (State (thy, a, b, c)) = let val (i, b') = get b i in (i, State (thy, a, b', c)) end
+fun thm i (State (thy, a, b, c)) = let val (i, c') = get c i in (i, State (thy, a, b, c')) end
+
+fun set (tab, no) v = (Inttab.update_new (no + 1, v) tab, no + 1)
+fun set_typ ty (State (thy, a, b, c)) = State (thy, set a ty, b, c)
+fun set_term tm (State (thy, a, b, c)) = State (thy, a, set b tm, c)
+fun set_thm th (State (thy, a, b, c)) = State (thy, a, b, set c th)
+
+fun last_thm (State (_, _, _, (tab, no))) =
+  case Inttab.lookup tab no of
     NONE => error "Import_Rule.last_thm: lookup failed"
   | SOME th => th
 
@@ -371,61 +383,63 @@
       | process (#"E", [th1, th2]) = thm th1 ##>> thm th2 #>> uncurry eq_mp #-> set_thm
       | process (#"D", [th1, th2]) = thm th1 ##>> thm th2 #>> uncurry deduct #-> set_thm
       | process (#"L", [t, th]) = term t ##>> (fn ti => thm th ti) #>> uncurry abs #-> set_thm
-      | process (#"M", [s]) = (fn (thy, state) =>
+      | process (#"M", [s]) = (fn state =>
           let
+            val thy = get_theory state
             val ctxt = Proof_Context.init_global thy
             val th = freezeT thy (Global_Theory.get_thm thy s)
             val ((_, [th']), _) = Variable.import true [th] ctxt
           in
-            set_thm th' (thy, state)
+            set_thm th' state
           end)
-      | process (#"Q", l) = (fn (thy, state) =>
+      | process (#"Q", l) = (fn state =>
           let
             val (tys, th) = list_last l
-            val (th, tstate) = thm th (thy, state)
-            val (tys, tstate) = fold_map typ tys tstate
+            val (th, state1) = thm th state
+            val (tys, state2) = fold_map typ tys state1
           in
-            set_thm (inst_type thy (pair_list tys) th) tstate
+            set_thm (inst_type (get_theory state) (pair_list tys) th) state2
           end)
-      | process (#"S", l) = (fn tstate =>
+      | process (#"S", l) = (fn state =>
           let
             val (tms, th) = list_last l
-            val (th, tstate) = thm th tstate
-            val (tms, tstate) = fold_map term tms tstate
+            val (th, state1) = thm th state
+            val (tms, state2) = fold_map term tms state1
           in
-            set_thm (inst (pair_list tms) th) tstate
+            set_thm (inst (pair_list tms) th) state2
           end)
-      | process (#"F", [name, t]) = (fn tstate =>
+      | process (#"F", [name, t]) = (fn state =>
           let
-            val (tm, (thy, state)) = term t tstate
-            val (th, thy) = def (make_name name) tm thy
+            val (tm, state1) = term t state
+            val (th, state2) = map_theory_result (def (make_name name) tm) state1
           in
-            set_thm th (thy, state)
+            set_thm th state2
           end)
-      | process (#"F", [name]) = (fn (thy, state) => set_thm (mdef thy name) (thy, state))
-      | process (#"Y", [name, absname, repname, t1, t2, th]) = (fn tstate =>
+      | process (#"F", [name]) = (fn state => set_thm (mdef (get_theory state) name) state)
+      | process (#"Y", [name, absname, repname, t1, t2, th]) = (fn state =>
           let
-            val (th, tstate) = thm th tstate
-            val (t1, tstate) = term t1 tstate
-            val (t2, (thy, state)) = term t2 tstate
-            val (th, thy) = tydef name absname repname t1 t2 th thy
+            val (th, state1) = thm th state
+            val (t1, state2) = term t1 state1
+            val (t2, state3) = term t2 state2
+            val (th, state4) = map_theory_result (tydef name absname repname t1 t2 th) state3
           in
-            set_thm th (thy, state)
+            set_thm th state4
           end)
-      | process (#"Y", [name, _, _]) = (fn (thy, state) => set_thm (mtydef thy name) (thy, state))
-      | process (#"t", [n]) = (fn (thy, state) =>
-          set_typ (Thm.global_ctyp_of thy (make_tfree n)) (thy, state))
-      | process (#"a", n :: l) = (fn (thy, state) =>
-          fold_map typ l (thy, state) |>>
-            (fn tys => Thm.global_ctyp_of thy (make_type thy (n, map Thm.typ_of tys))) |-> set_typ)
-      | process (#"v", [n, ty]) = (fn (thy, state) =>
-          typ ty (thy, state) |>> (fn ty => Thm.global_cterm_of thy (make_free (n, Thm.typ_of ty))) |-> set_term)
-      | process (#"c", [n, ty]) = (fn (thy, state) =>
-          typ ty (thy, state) |>> (fn ty => Thm.global_cterm_of thy (make_const thy (n, Thm.typ_of ty))) |-> set_term)
+      | process (#"Y", [name, _, _]) = (fn state => set_thm (mtydef (get_theory state) name) state)
+      | process (#"t", [n]) = (fn state => set_typ (ctyp_of state (make_tfree n)) state)
+      | process (#"a", n :: l) = (fn state =>
+          fold_map typ l state
+          |>> (fn tys => ctyp_of state (make_type (get_theory state) (n, map Thm.typ_of tys)))
+          |-> set_typ)
+      | process (#"v", [n, ty]) = (fn state =>
+          typ ty state |>> (fn ty => cterm_of state (make_free (n, Thm.typ_of ty))) |-> set_term)
+      | process (#"c", [n, ty]) = (fn state =>
+          typ ty state |>> (fn ty =>
+            cterm_of state (make_const (get_theory state) (n, Thm.typ_of ty))) |-> set_term)
       | process (#"f", [t1, t2]) = term t1 ##>> term t2 #>> uncurry Thm.apply #-> set_term
       | process (#"l", [t1, t2]) = term t1 ##>> term t2 #>> uncurry Thm.lambda #-> set_term
-      | process (#"+", [s]) = (fn (thy, state) =>
-          (store_thm (Binding.name (make_name s)) (last_thm state) thy, state))
+      | process (#"+", [s]) = (fn state =>
+          map_theory (store_thm (Binding.name (make_name s)) (last_thm state)) state)
       | process (c, _) = error ("process: unknown command: " ^ String.implode [c])
   in
     process (parse_line str)
@@ -437,7 +451,7 @@
     val lines =
       if Path.is_zst path then Bytes.read path |> Zstd.uncompress |> Bytes.trim_split_lines
       else File.read_lines path
-  in #1 (fold process_line lines (thy, init_state)) end
+  in get_theory (fold process_line lines (init_state thy)) end
 
 val _ =
   Outer_Syntax.command \<^command_keyword>\<open>import_file\<close> "import recorded proofs from HOL Light"