src/Pure/context.ML
changeset 77895 655bd3b0671b
parent 77894 186bd4012b78
child 77897 ff924ce0c599
--- a/src/Pure/context.ML	Thu Apr 20 15:26:34 2023 +0200
+++ b/src/Pure/context.ML	Thu Apr 20 21:26:35 2023 +0200
@@ -51,7 +51,7 @@
   val subthy: theory * theory -> bool
   val trace_theories: bool Unsynchronized.ref
   val theories_trace: unit -> {active_positions: Position.T list, active: int, total: int}
-  val join_thys: theory * theory -> theory
+  val join_thys: theory list -> theory
   val begin_thy: string -> theory list -> theory
   val finish_thy: theory -> theory
   val theory_data_sizeof1: theory -> (Position.T * int) list
@@ -95,7 +95,7 @@
   include CONTEXT
   structure Theory_Data:
   sig
-    val declare: Position.T -> Any.T -> (theory * theory -> Any.T * Any.T -> Any.T) -> serial
+    val declare: Position.T -> Any.T -> ((theory * Any.T) list -> Any.T) -> serial
     val get: serial -> (Any.T -> 'a) -> theory -> 'a
     val put: serial -> ('a -> Any.T) -> 'a -> theory -> theory
   end
@@ -118,6 +118,19 @@
 
 (** datatype theory **)
 
+(* implicit state *)
+
+type state = {stage: int} Synchronized.var;
+
+fun make_state () : state =
+  Synchronized.var "Context.state" {stage = 0};
+
+fun next_stage (state: state) =
+  Synchronized.change_result state (fn {stage} => (stage + 1, {stage = stage + 1}));
+
+
+(* theory_id *)
+
 abstype theory_id =
   Theory_Id of
    {id: serial,                   (*identifier*)
@@ -131,13 +144,16 @@
 
 end;
 
+
+(* theory *)
+
 datatype theory =
   Theory of
+   (*allocation state*)
+   state *
    (*identity*)
    {theory_id: theory_id,
     token: Position.T Unsynchronized.ref} *
-   (*allocation state*)
-   {next_stage: unit -> int} *
    (*ancestry*)
    {parents: theory list,         (*immediate predecessors*)
     ancestors: theory list} *     (*all predecessors -- canonical reverse order*)
@@ -148,16 +164,13 @@
 
 fun rep_theory (Theory args) = args;
 
-val theory_identity = #1 o rep_theory;
+val state_of = #1 o rep_theory;
+val theory_identity = #2 o rep_theory;
 val theory_id = #theory_id o theory_identity;
 val identity_of = rep_theory_id o theory_id;
-val state_of = #2 o rep_theory;
 val ancestry_of = #3 o rep_theory;
 val data_of = #4 o rep_theory;
 
-fun make_state () = {next_stage = Counter.make ()};
-fun next_stage {next_stage: unit -> int} = next_stage ();
-
 fun make_ancestry parents ancestors = {parents = parents, ancestors = ancestors};
 
 fun stage_final stage = stage = 0;
@@ -214,17 +227,11 @@
   else error ("Unfinished theory " ^ quote name);
 
 
-(* build ids *)
+(* identity *)
 
-val merge_ids =
-  apply2 identity_of #>
-  (fn ({id = id1, ids = ids1, ...}, {id = id2, ids = ids2, ...}) =>
-    Intset.merge (ids1, ids2)
-    |> Intset.insert id1
-    |> Intset.insert id2);
-
-
-(* equality and inclusion *)
+fun merge_ids thys =
+  fold (identity_of #> (fn {id, ids, ...} => fn acc => Intset.merge (acc, ids) |> Intset.insert id))
+    thys Intset.empty;
 
 val eq_thy_id = op = o apply2 (#id o rep_theory_id);
 val eq_thy = op = o apply2 (#id o identity_of);
@@ -250,6 +257,11 @@
 
 val merge_ancestors = merge eq_thy_consistent;
 
+val eq_ancestry =
+  apply2 ancestry_of #>
+    (fn ({parents, ancestors}, {parents = parents', ancestors = ancestors'}) =>
+      eq_list eq_thy (parents, parents') andalso eq_list eq_thy (ancestors, ancestors'));
+
 
 
 (** theory data **)
@@ -263,24 +275,25 @@
 type kind =
  {pos: Position.T,
   empty: Any.T,
-  merge: theory * theory -> Any.T * Any.T -> Any.T};
+  merge: (theory * Any.T) list -> Any.T};
 
 val kinds = Synchronized.var "Theory_Data" (Datatab.empty: kind Datatab.table);
 
-fun invoke name f k x =
+fun the_kind k =
   (case Datatab.lookup (Synchronized.value kinds) k of
-    SOME kind =>
-      if ! timing andalso name <> "" then
-        Timing.cond_timeit true ("Theory_Data." ^ name ^ Position.here (#pos kind))
-          (fn () => f kind x)
-      else f kind x
+    SOME kind => kind
   | NONE => raise Fail "Invalid theory data identifier");
 
 in
 
-fun invoke_pos k = invoke "" (K o #pos) k ();
-fun invoke_empty k = invoke "" (K o #empty) k ();
-fun invoke_merge thys = invoke "merge" (fn kind => #merge kind thys);
+val invoke_pos = #pos o the_kind;
+val invoke_empty = #empty o the_kind;
+
+fun invoke_merge kind args =
+  if ! timing then
+    Timing.cond_timeit true ("Theory_Data.merge" ^ Position.here (#pos kind))
+      (fn () => #merge kind args)
+  else #merge kind args;
 
 fun declare_data pos empty merge =
   let
@@ -289,12 +302,23 @@
     val _ = Synchronized.change kinds (Datatab.update (k, kind));
   in k end;
 
+fun lookup_data k thy = Datatab.lookup (data_of thy) k;
+
 fun get_data k thy =
-  (case Datatab.lookup (data_of thy) k of
+  (case lookup_data k thy of
     SOME x => x
   | NONE => invoke_empty k);
 
-fun merge_data thys = Datatab.join (invoke_merge thys);
+fun merge_data [] = Datatab.empty
+  | merge_data [thy] = data_of thy
+  | merge_data thys =
+      let
+        fun merge (k, kind) data =
+          (case map_filter (fn thy => lookup_data k thy |> Option.map (pair thy)) thys of
+            [] => data
+          | [(_, x)] => Datatab.default (k, x) data
+          | args => Datatab.update (k, invoke_merge kind args) data);
+      in Datatab.fold merge (Synchronized.value kinds) (data_of (hd thys)) end;
 
 end;
 
@@ -336,11 +360,11 @@
      total = length trace}
   end;
 
-fun create_thy ids name stage state ancestry data =
+fun create_thy state ids name stage ancestry data =
   let
     val theory_id = make_theory_id {id = serial (), ids = ids, name = name, stage = stage};
-    val token = make_token ();
-  in Theory ({theory_id = theory_id, token = token}, state, ancestry, data) end;
+    val identity = {theory_id = theory_id, token = make_token ()};
+  in Theory (state, identity, ancestry, data) end;
 
 end;
 
@@ -351,105 +375,77 @@
   let
     val state = make_state ();
     val stage = next_stage state;
-  in create_thy Intset.empty PureN stage state (make_ancestry [] []) Datatab.empty end;
+  in create_thy state Intset.empty PureN stage (make_ancestry [] []) Datatab.empty end;
 
 local
 
 fun change_thy finish f thy =
   let
-    val {id, ids, name, stage} = identity_of thy;
-    val Theory (_, state, ancestry, data) = thy;
+    val {name, stage, ...} = identity_of thy;
+    val Theory (state, _, ancestry, data) = thy;
     val ancestry' =
       if stage_final stage
       then make_ancestry [thy] (extend_ancestors thy (ancestors_of thy))
       else ancestry;
-    val ids' = Intset.insert id ids;
+    val ids' = merge_ids [thy];
     val stage' = if finish then 0 else next_stage state;
     val data' = f data;
-  in create_thy ids' name stage' state ancestry' data' end;
+  in create_thy state ids' name stage' ancestry' data' end;
 
 in
 
 val update_thy = change_thy false;
-val extend_thy = change_thy false I;
 val finish_thy = change_thy true I;
 
 end;
 
 
-(* join: anonymous theory nodes *)
-
-local
-
-fun bad_join (thy1, thy2) = raise THEORY ("Cannot join theories", [thy1, thy2]);
+(* join: unfinished theory nodes *)
 
-fun join_stage (thy1, thy2) =
-  apply2 identity_of (thy1, thy2) |>
-    (fn ({name, stage, ...}, {name = name', stage = stage', ...}) =>
-      if name <> name' orelse stage_final stage orelse stage_final stage'
-      then bad_join (thy1, thy2)
-      else
-        let val state = state_of thy1
-        in {name = name, stage = next_stage state, state = state} end)
+fun join_thys [] = raise List.Empty
+  | join_thys thys =
+      let
+        val thy0 = hd thys;
+        val name0 = theory_long_name thy0;
+        val state0 = state_of thy0;
 
-fun join_ancestry thys =
-  apply2 ancestry_of thys |>
-  (fn (ancestry as {parents, ancestors}, {parents = parents', ancestors = ancestors'}) =>
-    if eq_list eq_thy (parents, parents') andalso eq_list eq_thy (ancestors, ancestors')
-    then ancestry else bad_join thys);
-
-in
+        fun ok thy =
+          not (theory_id_final (theory_id thy)) andalso
+          theory_long_name thy = name0 andalso
+          eq_ancestry (thy0, thy);
+        val _ =
+          (case filter_out ok thys of
+            [] => ()
+          | bad => raise THEORY ("Cannot join theories", bad));
 
-fun join_thys thys =
-  let
-    val ids = merge_ids thys;
-    val {name, stage, state} = join_stage thys;
-    val ancestry = join_ancestry thys;
-    val data = merge_data thys (apply2 data_of thys);
-  in create_thy ids name stage state ancestry data end;
-
-end;
+        val stage = next_stage state0;
+        val ids = merge_ids thys;
+        val data = merge_data thys;
+      in create_thy state0 ids name0 stage (ancestry_of thy0) data end;
 
 
-(* merge: named theory nodes *)
-
-local
+(* merge: finished theory nodes *)
 
-fun merge_thys thys =
-  let
-    val ids = merge_ids thys;
-    val state = state_of (#1 thys);
-    val ancestry = make_ancestry [] [];
-    val data = merge_data thys (apply2 data_of thys);
-  in create_thy ids "" 0 state ancestry data end;
-
-fun maximal_thys thys =
-  thys |> filter_out (fn thy => exists (fn thy' => proper_subthy (thy, thy')) thys);
-
-in
+fun make_parents thys =
+  let val thys' = distinct eq_thy thys
+  in thys' |> filter_out (fn thy => exists (fn thy' => proper_subthy (thy, thy')) thys') end;
 
 fun begin_thy name imports =
   if name = "" then error ("Bad theory name: " ^ quote name)
+  else if null imports then error "Missing theory imports"
   else
     let
-      val parents = maximal_thys (distinct eq_thy imports);
+      val parents = make_parents imports;
       val ancestors =
-        Library.foldl merge_ancestors ([], map ancestors_of parents)
+        Library.foldl1 merge_ancestors (map ancestors_of parents)
         |> fold extend_ancestors parents;
-
-      val thy0 =
-        (case parents of
-          [] => error "Missing theory imports"
-        | [thy] => extend_thy thy
-        | thy :: thys => Library.foldl merge_thys (thy, thys));
-      val ids = #ids (identity_of thy0);
+      val ancestry = make_ancestry parents ancestors;
 
       val state = make_state ();
       val stage = next_stage state;
-      val ancestry = make_ancestry parents ancestors;
-    in create_thy ids name stage state ancestry (data_of thy0) |> tap finish_thy end;
-
-end;
+      val ids = merge_ids parents;
+      val data = merge_data parents;
+    in create_thy state ids name stage ancestry data |> tap finish_thy end;
 
 
 (* theory data *)
@@ -641,7 +637,7 @@
 sig
   type T
   val empty: T
-  val merge: theory * theory -> T * T -> T
+  val merge: (theory * T) list -> T
 end;
 
 signature THEORY_DATA_ARGS =
@@ -670,7 +666,7 @@
     Context.Theory_Data.declare
       pos
       (Data Data.empty)
-      (fn thys => fn (Data x1, Data x2) => Data (Data.merge thys (x1, x2)))
+      (Data o Data.merge o map (fn (thy, Data x) => (thy, x)))
   end;
 
 val get = Context.Theory_Data.get kind (fn Data x => x);
@@ -684,7 +680,7 @@
   (
     type T = Data.T;
     val empty = Data.empty;
-    fun merge _ = Data.merge;
+    fun merge args = Library.foldl (fn (a, (_, b)) => Data.merge (a, b)) (#2 (hd args), tl args)
   );