src/Pure/defs.ML
changeset 61261 ddb2da7cb2e4
parent 61260 e6f03fae14d5
child 61262 7bd1eb4b056e
--- a/src/Pure/defs.ML	Thu Sep 24 13:33:42 2015 +0200
+++ b/src/Pure/defs.ML	Thu Sep 24 23:33:29 2015 +0200
@@ -9,11 +9,14 @@
 sig
   datatype item_kind = Const | Type
   type item = item_kind * string
-  val item_ord: item * item -> order
   type entry = item * typ list
+  val item_kind_ord: item_kind * item_kind -> order
+  val plain_args: typ list -> bool
+  type context = Proof.context * (Name_Space.T * Name_Space.T) option
+  val space: context -> item_kind -> Name_Space.T
+  val pretty_item: context -> item -> Pretty.T
   val pretty_args: Proof.context -> typ list -> Pretty.T list
-  val pretty_entry: Proof.context -> entry -> Pretty.T
-  val plain_args: typ list -> bool
+  val pretty_entry: context -> entry -> Pretty.T
   type T
   type spec =
    {def: string option,
@@ -27,8 +30,8 @@
    {restricts: (entry * string) list,
     reducts: (entry * entry list) list}
   val empty: T
-  val merge: Proof.context -> T * T -> T
-  val define: Proof.context -> bool -> string option -> string -> entry -> entry list -> T -> T
+  val merge: context -> T * T -> T
+  val define: context -> bool -> string option -> string -> entry -> entry list -> T -> T
   val get_deps: T -> item -> (typ list * entry list) list
 end;
 
@@ -39,29 +42,41 @@
 
 datatype item_kind = Const | Type;
 type item = item_kind * string;
+type entry = item * typ list;
 
 fun item_kind_ord (Const, Type) = LESS
   | item_kind_ord (Type, Const) = GREATER
   | item_kind_ord _ = EQUAL;
 
-val item_ord = prod_ord item_kind_ord string_ord;
-val fast_item_ord = prod_ord item_kind_ord fast_string_ord;
-
-fun print_item (k, s) = if k = Const then s else "type " ^ s;
-
-structure Itemtab = Table(type key = item val ord = fast_item_ord);
+structure Itemtab = Table(type key = item val ord = prod_ord item_kind_ord fast_string_ord);
 
 
-(* type arguments *)
+(* pretty printing *)
+
+type context = Proof.context * (Name_Space.T * Name_Space.T) option;
 
-type entry = item * typ list;
+fun space (ctxt, spaces) kind =
+  (case (kind, spaces) of
+    (Const, SOME (const_space, _)) => const_space
+  | (Type, SOME (_, type_space)) => type_space
+  | (Const, NONE) => Sign.const_space (Proof_Context.theory_of ctxt)
+  | (Type, NONE) => Sign.type_space (Proof_Context.theory_of ctxt));
+
+fun pretty_item (context as (ctxt, _)) (kind, name) =
+  let val prt_name = Name_Space.pretty ctxt (space context kind) name in
+    if kind = Const then prt_name
+    else Pretty.block [Pretty.keyword1 "type", Pretty.brk 1, prt_name]
+  end;
 
 fun pretty_args ctxt args =
   if null args then []
   else [Pretty.list "(" ")" (map (Syntax.pretty_typ ctxt o Logic.unvarifyT_global) args)];
 
-fun pretty_entry ctxt (c, args) =
-  Pretty.block (Pretty.str (print_item c) :: pretty_args ctxt args);
+fun pretty_entry context (c, args) =
+  Pretty.block (pretty_item context c :: pretty_args (#1 context) args);
+
+
+(* type arguments *)
 
 fun plain_args args =
   forall Term.is_TVar args andalso not (has_duplicates (op =) args);
@@ -129,21 +144,22 @@
 
 (* specifications *)
 
-fun disjoint_specs c (i, {description = a, pos = pos_a, lhs = Ts, ...}: spec) =
+fun disjoint_specs context c (i, {description = a, pos = pos_a, lhs = Ts, ...}: spec) =
   Inttab.forall (fn (j, {description = b, pos = pos_b, lhs = Us, ...}: spec) =>
     i = j orelse disjoint_args (Ts, Us) orelse
-      error ("Clash of specifications for " ^ print_item c ^ ":\n" ^
+      error ("Clash of specifications for " ^ Pretty.str_of (pretty_item context c) ^ ":\n" ^
         "  " ^ quote a ^ Position.here pos_a ^ "\n" ^
         "  " ^ quote b ^ Position.here pos_b));
 
-fun join_specs c ({specs = specs1, restricts, reducts}, {specs = specs2, ...}: def) =
+fun join_specs context c ({specs = specs1, restricts, reducts}, {specs = specs2, ...}: def) =
   let
     val specs' =
-      Inttab.fold (fn spec2 => (disjoint_specs c spec2 specs1; Inttab.update spec2)) specs2 specs1;
+      Inttab.fold (fn spec2 => (disjoint_specs context c spec2 specs1; Inttab.update spec2))
+        specs2 specs1;
   in make_def (specs', restricts, reducts) end;
 
-fun update_specs c spec = map_def c (fn (specs, restricts, reducts) =>
-  (disjoint_specs c spec specs; (Inttab.update spec specs, restricts, reducts)));
+fun update_specs context c spec = map_def c (fn (specs, restricts, reducts) =>
+  (disjoint_specs context c spec specs; (Inttab.update spec specs, restricts, reducts)));
 
 
 (* normalized dependencies: reduction with well-formedness check *)
@@ -151,23 +167,24 @@
 local
 
 val prt = Pretty.string_of oo pretty_entry;
-fun err ctxt (c, args) (d, Us) s1 s2 =
-  error (s1 ^ " dependency of " ^ prt ctxt (c, args) ^ " -> " ^ prt ctxt (d, Us) ^ s2);
 
-fun acyclic ctxt (c, args) (d, Us) =
+fun err context (c, args) (d, Us) s1 s2 =
+  error (s1 ^ " dependency of " ^ prt context (c, args) ^ " -> " ^ prt context (d, Us) ^ s2);
+
+fun acyclic context (c, args) (d, Us) =
   c <> d orelse
   is_none (match_args (args, Us)) orelse
-  err ctxt (c, args) (d, Us) "Circular" "";
+  err context (c, args) (d, Us) "Circular" "";
 
-fun wellformed ctxt defs (c, args) (d, Us) =
+fun wellformed context defs (c, args) (d, Us) =
   plain_args Us orelse
   (case find_first (fn (Ts, _) => not (disjoint_args (Ts, Us))) (restricts_of defs d) of
     SOME (Ts, description) =>
-      err ctxt (c, args) (d, Us) "Malformed"
-        ("\n(restriction " ^ prt ctxt (d, Ts) ^ " from " ^ quote description ^ ")")
+      err context (c, args) (d, Us) "Malformed"
+        ("\n(restriction " ^ prt context (d, Ts) ^ " from " ^ quote description ^ ")")
   | NONE => true);
 
-fun reduction ctxt defs const deps =
+fun reduction context defs const deps =
   let
     fun reduct Us (Ts, rhs) =
       (case match_args (Ts, Us) of
@@ -180,17 +197,17 @@
       if forall (is_none o #1) reds then NONE
       else SOME (fold_rev
         (fn (NONE, dp) => insert (op =) dp | (SOME dps, _) => fold (insert (op =)) dps) reds []);
-    val _ = forall (acyclic ctxt const) (the_default deps deps');
+    val _ = forall (acyclic context const) (the_default deps deps');
   in deps' end;
 
 in
 
-fun normalize ctxt =
+fun normalize context =
   let
     fun norm_update (c, {reducts, ...}: def) (changed, defs) =
       let
         val reducts' = reducts |> map (fn (args, deps) =>
-          (args, perhaps (reduction ctxt defs (c, args)) deps));
+          (args, perhaps (reduction context defs (c, args)) deps));
       in
         if reducts = reducts' then (changed, defs)
         else (true, defs |> map_def c (fn (specs, restricts, _) => (specs, restricts, reducts')))
@@ -200,38 +217,38 @@
         (true, defs') => norm_all defs'
       | (false, _) => defs);
     fun check defs (c, {reducts, ...}: def) =
-      reducts |> forall (fn (args, deps) => forall (wellformed ctxt defs (c, args)) deps);
+      reducts |> forall (fn (args, deps) => forall (wellformed context defs (c, args)) deps);
   in norm_all #> (fn defs => tap (Itemtab.forall (check defs)) defs) end;
 
-fun dependencies ctxt (c, args) restr deps =
+fun dependencies context (c, args) restr deps =
   map_def c (fn (specs, restricts, reducts) =>
     let
       val restricts' = Library.merge (op =) (restricts, restr);
       val reducts' = insert (op =) (args, deps) reducts;
     in (specs, restricts', reducts') end)
-  #> normalize ctxt;
+  #> normalize context;
 
 end;
 
 
 (* merge *)
 
-fun merge ctxt (Defs defs1, Defs defs2) =
+fun merge context (Defs defs1, Defs defs2) =
   let
     fun add_deps (c, args) restr deps defs =
       if AList.defined (op =) (reducts_of defs c) args then defs
-      else dependencies ctxt (c, args) restr deps defs;
+      else dependencies context (c, args) restr deps defs;
     fun add_def (c, {restricts, reducts, ...}: def) =
       fold (fn (args, deps) => add_deps (c, args) restricts deps) reducts;
   in
-    Defs (Itemtab.join join_specs (defs1, defs2)
-      |> normalize ctxt |> Itemtab.fold add_def defs2)
+    Defs (Itemtab.join (join_specs context) (defs1, defs2)
+      |> normalize context |> Itemtab.fold add_def defs2)
   end;
 
 
 (* define *)
 
-fun define ctxt unchecked def description (c, args) deps (Defs defs) =
+fun define context unchecked def description (c, args) deps (Defs defs) =
   let
     val pos = Position.thread_data ();
     val restr =
@@ -240,8 +257,8 @@
       then [] else [(args, description)];
     val spec =
       (serial (), {def = def, description = description, pos = pos, lhs = args, rhs = deps});
-    val defs' = defs |> update_specs c spec;
-  in Defs (defs' |> (if unchecked then I else dependencies ctxt (c, args) restr deps)) end;
+    val defs' = defs |> update_specs context c spec;
+  in Defs (defs' |> (if unchecked then I else dependencies context (c, args) restr deps)) end;
 
 fun get_deps (Defs defs) c = reducts_of defs c;