support n-ary merge theory data;
authorwenzelm
Thu, 20 Apr 2023 21:26:35 +0200
changeset 77895 655bd3b0671b
parent 77894 186bd4012b78
child 77896 a9626bcb0c3b
support n-ary merge theory data; less redundant use of ids and stages;
src/HOL/ex/Join_Theory.thy
src/Pure/Isar/locale.ML
src/Pure/axclass.ML
src/Pure/context.ML
src/Pure/sign.ML
src/Pure/theory.ML
--- a/src/HOL/ex/Join_Theory.thy	Thu Apr 20 15:26:34 2023 +0200
+++ b/src/HOL/ex/Join_Theory.thy	Thu Apr 20 21:26:35 2023 +0200
@@ -35,7 +35,7 @@
 setup \<open>
   fn thy =>
     let val forked_thys = Par_List.map (fn i => Named_Target.theory_map (spec i) thy) (1 upto 10)
-    in Theory.join_theory forked_thys end
+    in Context.join_thys forked_thys end
 \<close>
 
 term test1
--- a/src/Pure/Isar/locale.ML	Thu Apr 20 15:26:34 2023 +0200
+++ b/src/Pure/Isar/locale.ML	Thu Apr 20 21:26:35 2023 +0200
@@ -355,8 +355,9 @@
     unique registration serial points to mixin list*)
   type T = reg Idtab.table * mixins;
   val empty: T = (Idtab.empty, Inttab.empty);
-  fun merge old_thys =
+  fun merge args =
     let
+      val ctxt0 = Syntax.init_pretty_global (#1 (hd args));
       fun recursive_merge ((regs1, mixins1), (regs2, mixins2)) : T =
         (Idtab.merge eq_reg (regs1, regs2), merge_mixins (mixins1, mixins2))
         handle Idtab.DUP id =>
@@ -373,9 +374,9 @@
             val _ =
               warning ("Removed duplicate interpretation after retrieving its mixins" ^
                 Position.here_list [#pos reg1, #pos reg2] ^ ":\n  " ^
-                Pretty.string_of (pretty_reg_inst (Syntax.init_pretty_global (#1 old_thys)) [] id));
+                Pretty.string_of (pretty_reg_inst ctxt0 [] id));
           in recursive_merge ((regs1, mixins1), (regs2', mixins2')) end;
-    in recursive_merge end;
+    in Library.foldl1 recursive_merge (map #2 args) end;
 );
 
 structure Local_Registrations = Proof_Data
--- a/src/Pure/axclass.ML	Thu Apr 20 15:26:34 2023 +0200
+++ b/src/Pure/axclass.ML	Thu Apr 20 21:26:35 2023 +0200
@@ -74,6 +74,8 @@
       (*constant name ~> type constructor ~> (constant name, equation)*)
     (string * string) Symtab.table (*constant name ~> (constant name, type constructor)*)};
 
+fun rep_data (Data args) = args;
+
 fun make_data (axclasses, params, inst_params) =
   Data {axclasses = axclasses, params = params, inst_params = inst_params};
 
@@ -81,22 +83,23 @@
 (
   type T = data;
   val empty = make_data (Symtab.empty, [], (Symtab.empty, Symtab.empty));
-  fun merge old_thys
-      (Data {axclasses = axclasses1, params = params1, inst_params = inst_params1},
-       Data {axclasses = axclasses2, params = params2, inst_params = inst_params2}) =
+  fun merge args =
     let
-      val old_ctxt = Syntax.init_pretty_global (fst old_thys);
+      val ctxt0 = Syntax.init_pretty_global (#1 (hd args));
 
-      val axclasses' = Symtab.merge (K true) (axclasses1, axclasses2);
-      val params' =
+      fun merge_params (params1, params2) =
         if null params1 then params2
         else
-          fold_rev (fn p => if member (op =) params1 p then I else add_param old_ctxt p)
+          fold_rev (fn p => if member (op =) params1 p then I else add_param ctxt0 p)
             params2 params1;
 
-      val inst_params' =
+      fun merge_inst_params (inst_params1, inst_params2) =
         (Symtab.join (K (Symtab.merge (K true))) (#1 inst_params1, #1 inst_params2),
           Symtab.merge (K true) (#2 inst_params1, #2 inst_params2));
+
+      val axclasses' = Library.foldl1 (Symtab.merge (K true)) (map (#axclasses o rep_data o #2) args);
+      val params' = Library.foldl1 merge_params (map (#params o rep_data o #2) args);
+      val inst_params' = Library.foldl1 merge_inst_params (map (#inst_params o rep_data o #2) args);
     in make_data (axclasses', params', inst_params') end;
 );
 
@@ -116,11 +119,11 @@
   map_data (fn (axclasses, params, inst_params) =>
     (axclasses, params, f inst_params));
 
-val rep_data = Data.get #> (fn Data args => args);
+val rep_theory_data = Data.get #> rep_data;
 
-val axclasses_of = #axclasses o rep_data;
-val params_of = #params o rep_data;
-val inst_params_of = #inst_params o rep_data;
+val axclasses_of = #axclasses o rep_theory_data;
+val params_of = #params o rep_theory_data;
+val inst_params_of = #inst_params o rep_theory_data;
 
 
 (* axclasses with parameters *)
--- a/src/Pure/context.ML	Thu Apr 20 15:26:34 2023 +0200
+++ b/src/Pure/context.ML	Thu Apr 20 21:26:35 2023 +0200
@@ -51,7 +51,7 @@
   val subthy: theory * theory -> bool
   val trace_theories: bool Unsynchronized.ref
   val theories_trace: unit -> {active_positions: Position.T list, active: int, total: int}
-  val join_thys: theory * theory -> theory
+  val join_thys: theory list -> theory
   val begin_thy: string -> theory list -> theory
   val finish_thy: theory -> theory
   val theory_data_sizeof1: theory -> (Position.T * int) list
@@ -95,7 +95,7 @@
   include CONTEXT
   structure Theory_Data:
   sig
-    val declare: Position.T -> Any.T -> (theory * theory -> Any.T * Any.T -> Any.T) -> serial
+    val declare: Position.T -> Any.T -> ((theory * Any.T) list -> Any.T) -> serial
     val get: serial -> (Any.T -> 'a) -> theory -> 'a
     val put: serial -> ('a -> Any.T) -> 'a -> theory -> theory
   end
@@ -118,6 +118,19 @@
 
 (** datatype theory **)
 
+(* implicit state *)
+
+type state = {stage: int} Synchronized.var;
+
+fun make_state () : state =
+  Synchronized.var "Context.state" {stage = 0};
+
+fun next_stage (state: state) =
+  Synchronized.change_result state (fn {stage} => (stage + 1, {stage = stage + 1}));
+
+
+(* theory_id *)
+
 abstype theory_id =
   Theory_Id of
    {id: serial,                   (*identifier*)
@@ -131,13 +144,16 @@
 
 end;
 
+
+(* theory *)
+
 datatype theory =
   Theory of
+   (*allocation state*)
+   state *
    (*identity*)
    {theory_id: theory_id,
     token: Position.T Unsynchronized.ref} *
-   (*allocation state*)
-   {next_stage: unit -> int} *
    (*ancestry*)
    {parents: theory list,         (*immediate predecessors*)
     ancestors: theory list} *     (*all predecessors -- canonical reverse order*)
@@ -148,16 +164,13 @@
 
 fun rep_theory (Theory args) = args;
 
-val theory_identity = #1 o rep_theory;
+val state_of = #1 o rep_theory;
+val theory_identity = #2 o rep_theory;
 val theory_id = #theory_id o theory_identity;
 val identity_of = rep_theory_id o theory_id;
-val state_of = #2 o rep_theory;
 val ancestry_of = #3 o rep_theory;
 val data_of = #4 o rep_theory;
 
-fun make_state () = {next_stage = Counter.make ()};
-fun next_stage {next_stage: unit -> int} = next_stage ();
-
 fun make_ancestry parents ancestors = {parents = parents, ancestors = ancestors};
 
 fun stage_final stage = stage = 0;
@@ -214,17 +227,11 @@
   else error ("Unfinished theory " ^ quote name);
 
 
-(* build ids *)
+(* identity *)
 
-val merge_ids =
-  apply2 identity_of #>
-  (fn ({id = id1, ids = ids1, ...}, {id = id2, ids = ids2, ...}) =>
-    Intset.merge (ids1, ids2)
-    |> Intset.insert id1
-    |> Intset.insert id2);
-
-
-(* equality and inclusion *)
+fun merge_ids thys =
+  fold (identity_of #> (fn {id, ids, ...} => fn acc => Intset.merge (acc, ids) |> Intset.insert id))
+    thys Intset.empty;
 
 val eq_thy_id = op = o apply2 (#id o rep_theory_id);
 val eq_thy = op = o apply2 (#id o identity_of);
@@ -250,6 +257,11 @@
 
 val merge_ancestors = merge eq_thy_consistent;
 
+val eq_ancestry =
+  apply2 ancestry_of #>
+    (fn ({parents, ancestors}, {parents = parents', ancestors = ancestors'}) =>
+      eq_list eq_thy (parents, parents') andalso eq_list eq_thy (ancestors, ancestors'));
+
 
 
 (** theory data **)
@@ -263,24 +275,25 @@
 type kind =
  {pos: Position.T,
   empty: Any.T,
-  merge: theory * theory -> Any.T * Any.T -> Any.T};
+  merge: (theory * Any.T) list -> Any.T};
 
 val kinds = Synchronized.var "Theory_Data" (Datatab.empty: kind Datatab.table);
 
-fun invoke name f k x =
+fun the_kind k =
   (case Datatab.lookup (Synchronized.value kinds) k of
-    SOME kind =>
-      if ! timing andalso name <> "" then
-        Timing.cond_timeit true ("Theory_Data." ^ name ^ Position.here (#pos kind))
-          (fn () => f kind x)
-      else f kind x
+    SOME kind => kind
   | NONE => raise Fail "Invalid theory data identifier");
 
 in
 
-fun invoke_pos k = invoke "" (K o #pos) k ();
-fun invoke_empty k = invoke "" (K o #empty) k ();
-fun invoke_merge thys = invoke "merge" (fn kind => #merge kind thys);
+val invoke_pos = #pos o the_kind;
+val invoke_empty = #empty o the_kind;
+
+fun invoke_merge kind args =
+  if ! timing then
+    Timing.cond_timeit true ("Theory_Data.merge" ^ Position.here (#pos kind))
+      (fn () => #merge kind args)
+  else #merge kind args;
 
 fun declare_data pos empty merge =
   let
@@ -289,12 +302,23 @@
     val _ = Synchronized.change kinds (Datatab.update (k, kind));
   in k end;
 
+fun lookup_data k thy = Datatab.lookup (data_of thy) k;
+
 fun get_data k thy =
-  (case Datatab.lookup (data_of thy) k of
+  (case lookup_data k thy of
     SOME x => x
   | NONE => invoke_empty k);
 
-fun merge_data thys = Datatab.join (invoke_merge thys);
+fun merge_data [] = Datatab.empty
+  | merge_data [thy] = data_of thy
+  | merge_data thys =
+      let
+        fun merge (k, kind) data =
+          (case map_filter (fn thy => lookup_data k thy |> Option.map (pair thy)) thys of
+            [] => data
+          | [(_, x)] => Datatab.default (k, x) data
+          | args => Datatab.update (k, invoke_merge kind args) data);
+      in Datatab.fold merge (Synchronized.value kinds) (data_of (hd thys)) end;
 
 end;
 
@@ -336,11 +360,11 @@
      total = length trace}
   end;
 
-fun create_thy ids name stage state ancestry data =
+fun create_thy state ids name stage ancestry data =
   let
     val theory_id = make_theory_id {id = serial (), ids = ids, name = name, stage = stage};
-    val token = make_token ();
-  in Theory ({theory_id = theory_id, token = token}, state, ancestry, data) end;
+    val identity = {theory_id = theory_id, token = make_token ()};
+  in Theory (state, identity, ancestry, data) end;
 
 end;
 
@@ -351,105 +375,77 @@
   let
     val state = make_state ();
     val stage = next_stage state;
-  in create_thy Intset.empty PureN stage state (make_ancestry [] []) Datatab.empty end;
+  in create_thy state Intset.empty PureN stage (make_ancestry [] []) Datatab.empty end;
 
 local
 
 fun change_thy finish f thy =
   let
-    val {id, ids, name, stage} = identity_of thy;
-    val Theory (_, state, ancestry, data) = thy;
+    val {name, stage, ...} = identity_of thy;
+    val Theory (state, _, ancestry, data) = thy;
     val ancestry' =
       if stage_final stage
       then make_ancestry [thy] (extend_ancestors thy (ancestors_of thy))
       else ancestry;
-    val ids' = Intset.insert id ids;
+    val ids' = merge_ids [thy];
     val stage' = if finish then 0 else next_stage state;
     val data' = f data;
-  in create_thy ids' name stage' state ancestry' data' end;
+  in create_thy state ids' name stage' ancestry' data' end;
 
 in
 
 val update_thy = change_thy false;
-val extend_thy = change_thy false I;
 val finish_thy = change_thy true I;
 
 end;
 
 
-(* join: anonymous theory nodes *)
-
-local
-
-fun bad_join (thy1, thy2) = raise THEORY ("Cannot join theories", [thy1, thy2]);
+(* join: unfinished theory nodes *)
 
-fun join_stage (thy1, thy2) =
-  apply2 identity_of (thy1, thy2) |>
-    (fn ({name, stage, ...}, {name = name', stage = stage', ...}) =>
-      if name <> name' orelse stage_final stage orelse stage_final stage'
-      then bad_join (thy1, thy2)
-      else
-        let val state = state_of thy1
-        in {name = name, stage = next_stage state, state = state} end)
+fun join_thys [] = raise List.Empty
+  | join_thys thys =
+      let
+        val thy0 = hd thys;
+        val name0 = theory_long_name thy0;
+        val state0 = state_of thy0;
 
-fun join_ancestry thys =
-  apply2 ancestry_of thys |>
-  (fn (ancestry as {parents, ancestors}, {parents = parents', ancestors = ancestors'}) =>
-    if eq_list eq_thy (parents, parents') andalso eq_list eq_thy (ancestors, ancestors')
-    then ancestry else bad_join thys);
-
-in
+        fun ok thy =
+          not (theory_id_final (theory_id thy)) andalso
+          theory_long_name thy = name0 andalso
+          eq_ancestry (thy0, thy);
+        val _ =
+          (case filter_out ok thys of
+            [] => ()
+          | bad => raise THEORY ("Cannot join theories", bad));
 
-fun join_thys thys =
-  let
-    val ids = merge_ids thys;
-    val {name, stage, state} = join_stage thys;
-    val ancestry = join_ancestry thys;
-    val data = merge_data thys (apply2 data_of thys);
-  in create_thy ids name stage state ancestry data end;
-
-end;
+        val stage = next_stage state0;
+        val ids = merge_ids thys;
+        val data = merge_data thys;
+      in create_thy state0 ids name0 stage (ancestry_of thy0) data end;
 
 
-(* merge: named theory nodes *)
-
-local
+(* merge: finished theory nodes *)
 
-fun merge_thys thys =
-  let
-    val ids = merge_ids thys;
-    val state = state_of (#1 thys);
-    val ancestry = make_ancestry [] [];
-    val data = merge_data thys (apply2 data_of thys);
-  in create_thy ids "" 0 state ancestry data end;
-
-fun maximal_thys thys =
-  thys |> filter_out (fn thy => exists (fn thy' => proper_subthy (thy, thy')) thys);
-
-in
+fun make_parents thys =
+  let val thys' = distinct eq_thy thys
+  in thys' |> filter_out (fn thy => exists (fn thy' => proper_subthy (thy, thy')) thys') end;
 
 fun begin_thy name imports =
   if name = "" then error ("Bad theory name: " ^ quote name)
+  else if null imports then error "Missing theory imports"
   else
     let
-      val parents = maximal_thys (distinct eq_thy imports);
+      val parents = make_parents imports;
       val ancestors =
-        Library.foldl merge_ancestors ([], map ancestors_of parents)
+        Library.foldl1 merge_ancestors (map ancestors_of parents)
         |> fold extend_ancestors parents;
-
-      val thy0 =
-        (case parents of
-          [] => error "Missing theory imports"
-        | [thy] => extend_thy thy
-        | thy :: thys => Library.foldl merge_thys (thy, thys));
-      val ids = #ids (identity_of thy0);
+      val ancestry = make_ancestry parents ancestors;
 
       val state = make_state ();
       val stage = next_stage state;
-      val ancestry = make_ancestry parents ancestors;
-    in create_thy ids name stage state ancestry (data_of thy0) |> tap finish_thy end;
-
-end;
+      val ids = merge_ids parents;
+      val data = merge_data parents;
+    in create_thy state ids name stage ancestry data |> tap finish_thy end;
 
 
 (* theory data *)
@@ -641,7 +637,7 @@
 sig
   type T
   val empty: T
-  val merge: theory * theory -> T * T -> T
+  val merge: (theory * T) list -> T
 end;
 
 signature THEORY_DATA_ARGS =
@@ -670,7 +666,7 @@
     Context.Theory_Data.declare
       pos
       (Data Data.empty)
-      (fn thys => fn (Data x1, Data x2) => Data (Data.merge thys (x1, x2)))
+      (Data o Data.merge o map (fn (thy, Data x) => (thy, x)))
   end;
 
 val get = Context.Theory_Data.get kind (fn Data x => x);
@@ -684,7 +680,7 @@
   (
     type T = Data.T;
     val empty = Data.empty;
-    fun merge _ = Data.merge;
+    fun merge args = Library.foldl (fn (a, (_, b)) => Data.merge (a, b)) (#2 (hd args), tl args)
   );
 
 
--- a/src/Pure/sign.ML	Thu Apr 20 15:26:34 2023 +0200
+++ b/src/Pure/sign.ML	Thu Apr 20 21:26:35 2023 +0200
@@ -129,24 +129,23 @@
   tsig: Type.tsig,              (*order-sorted signature of types*)
   consts: Consts.T};            (*polymorphic constants*)
 
+fun rep_sign (Sign args) = args;
 fun make_sign (syn, tsig, consts) = Sign {syn = syn, tsig = tsig, consts = consts};
 
 structure Data = Theory_Data'
 (
   type T = sign;
   val empty = make_sign (Syntax.empty_syntax, Type.empty_tsig, Consts.empty);
-  fun merge old_thys (sign1, sign2) =
+  fun merge args =
     let
-      val Sign {syn = syn1, tsig = tsig1, consts = consts1} = sign1;
-      val Sign {syn = syn2, tsig = tsig2, consts = consts2} = sign2;
-
-      val syn = Syntax.merge_syntax (syn1, syn2);
-      val tsig = Type.merge_tsig (Context.Theory (fst old_thys)) (tsig1, tsig2);
-      val consts = Consts.merge (consts1, consts2);
-    in make_sign (syn, tsig, consts) end;
+      val context0 = Context.Theory (#1 (hd args));
+      val syn' = Library.foldl1 Syntax.merge_syntax (map (#syn o rep_sign o #2) args);
+      val tsig' = Library.foldl1 (Type.merge_tsig context0) (map (#tsig o rep_sign o #2) args);
+      val consts' = Library.foldl1 Consts.merge (map (#consts o rep_sign o #2) args);
+    in make_sign (syn', tsig', consts') end;
 );
 
-fun rep_sg thy = Data.get thy |> (fn Sign args => args);
+val rep_sg = rep_sign o Data.get;
 
 fun map_sign f = Data.map (fn Sign {syn, tsig, consts} => make_sign (f (syn, tsig, consts)));
 
--- a/src/Pure/theory.ML	Thu Apr 20 15:26:34 2023 +0200
+++ b/src/Pure/theory.ML	Thu Apr 20 21:26:35 2023 +0200
@@ -22,7 +22,6 @@
   val defs_of: theory -> Defs.T
   val at_begin: (theory -> theory option) -> theory -> theory
   val at_end: (theory -> theory option) -> theory -> theory
-  val join_theory: theory list -> theory
   val begin_theory: string * Position.T -> theory list -> theory
   val end_theory: theory -> theory
   val add_axiom: Proof.context -> binding * term -> theory -> theory
@@ -83,6 +82,8 @@
   defs: Defs.T,
   wrappers: wrapper list * wrapper list};
 
+fun rep_thy (Thy args) = args;
+
 fun make_thy (pos, id, axioms, defs, wrappers) =
   Thy {pos = pos, id = id, axioms = axioms, defs = defs, wrappers = wrappers};
 
@@ -90,19 +91,22 @@
 (
   type T = thy;
   val empty = make_thy (Position.none, 0, Name_Space.empty_table Markup.axiomN, Defs.empty, ([], []));
-  fun merge old_thys (thy1, thy2) =
+  fun merge args =
     let
-      val Thy {pos, id, axioms = axioms1, defs = defs1, wrappers = (bgs1, ens1)} = thy1;
-      val Thy {pos = _, id = _, axioms = axioms2, defs = defs2, wrappers = (bgs2, ens2)} = thy2;
+      val thy0 = #1 (hd args);
+      val {pos, id, ...} = rep_thy (#2 (hd args));
 
-      val axioms' = Name_Space.merge_tables (axioms1, axioms2);
-      val defs' = Defs.merge (Defs.global_context (fst old_thys)) (defs1, defs2);
-      val bgs' = Library.merge (eq_snd op =) (bgs1, bgs2);
-      val ens' = Library.merge (eq_snd op =) (ens1, ens2);
+      val merge_defs = Defs.merge (Defs.global_context thy0);
+      val merge_wrappers = Library.merge (eq_snd op =);
+
+      val axioms' = Library.foldl1 Name_Space.merge_tables (map (#axioms o rep_thy o #2) args);
+      val defs' = Library.foldl1 merge_defs (map (#defs o rep_thy o #2) args);
+      val bgs' = Library.foldl1 merge_wrappers (map (#1 o #wrappers o rep_thy o #2) args);
+      val ens' = Library.foldl1 merge_wrappers (map (#2 o #wrappers o rep_thy o #2) args);
     in make_thy (pos, id, axioms', defs', (bgs', ens')) end;
 );
 
-fun rep_theory thy = Thy.get thy |> (fn Thy args => args);
+val rep_theory = rep_thy o Thy.get;
 
 fun map_thy f = Thy.map (fn (Thy {pos, id, axioms, defs, wrappers}) =>
   make_thy (f (pos, id, axioms, defs, wrappers)));
@@ -162,13 +166,6 @@
 val defs_of = #defs o rep_theory;
 
 
-(* join theory *)
-
-fun join_theory [] = raise List.Empty
-  | join_theory [thy] = thy
-  | join_theory thys = foldl1 Context.join_thys thys;
-
-
 (* begin/end theory *)
 
 val begin_wrappers = rev o #1 o #wrappers o rep_theory;