abbreviate: always authentic, force expansion of internal abbreviations;
authorwenzelm
Sat, 09 Dec 2006 18:05:39 +0100
changeset 21720 059e6b8cee8e
parent 21719 b67fbfc8a126
child 21721 908a93216f00
abbreviate: always authentic, force expansion of internal abbreviations; tuned signature; tuned;
src/Pure/consts.ML
--- a/src/Pure/consts.ML	Sat Dec 09 18:05:38 2006 +0100
+++ b/src/Pure/consts.ML	Sat Dec 09 18:05:39 2006 +0100
@@ -12,11 +12,12 @@
   val eq_consts: T * T -> bool
   val abbrevs_of: T -> string list -> (term * term) list
   val dest: T ->
-   {constants: (typ * term option) NameSpace.table,
+   {constants: (typ * (term * term) option) NameSpace.table,
     constraints: typ NameSpace.table}
-  val declaration: T -> string -> typ                               (*exception TYPE*)
-  val monomorphic: T -> string -> bool                              (*exception TYPE*)
-  val constraint: T -> string -> typ                                (*exception TYPE*)
+  val the_abbreviation: T -> string -> typ * (term * term)          (*exception TYPE*)
+  val the_declaration: T -> string -> typ                           (*exception TYPE*)
+  val is_monomorphic: T -> string -> bool                           (*exception TYPE*)
+  val the_constraint: T -> string -> typ                            (*exception TYPE*)
   val space_of: T -> NameSpace.T
   val intern: T -> xstring -> string
   val extern: T -> string -> xstring
@@ -29,9 +30,9 @@
   val instance: T -> string * typ list -> typ
   val declare: NameSpace.naming -> (bstring * typ) * bool -> T -> T
   val constrain: string * typ option -> T -> T
-  val expand_abbrevs: bool -> T -> T
+  val set_expand: bool -> T -> T
   val abbreviate: Pretty.pp -> Type.tsig -> NameSpace.naming -> string ->
-    (bstring * term) * bool -> T -> ((string * typ) * term) * T
+    bstring * term -> T -> ((string * typ) * term) * T
   val hide: bool -> string -> T -> T
   val empty: T
   val merge: T * T -> T
@@ -46,8 +47,8 @@
 (* datatype T *)
 
 datatype kind =
-  LogicalConst of int list list |
-  Abbreviation of term;
+  LogicalConst of int list list |      (*typargs positions*)
+  Abbreviation of term * term * bool   (*rhs, normal rhs, force_expand*);
 
 type decl =
   (typ * kind) *
@@ -57,16 +58,16 @@
  {decls: (decl * serial) NameSpace.table,
   constraints: typ Symtab.table,
   rev_abbrevs: (term * term) list Symtab.table,
-  expand_abbrevs: bool} * stamp;
+  do_expand: bool} * stamp;
 
 fun eq_consts (Consts (_, s1), Consts (_, s2)) = s1 = s2;
 
-fun make_consts (decls, constraints, rev_abbrevs, expand_abbrevs) =
+fun make_consts (decls, constraints, rev_abbrevs, do_expand) =
   Consts ({decls = decls, constraints = constraints, rev_abbrevs = rev_abbrevs,
-    expand_abbrevs = expand_abbrevs}, stamp ());
+    do_expand = do_expand}, stamp ());
 
-fun map_consts f (Consts ({decls, constraints, rev_abbrevs, expand_abbrevs}, _)) =
-  make_consts (f (decls, constraints, rev_abbrevs, expand_abbrevs));
+fun map_consts f (Consts ({decls, constraints, rev_abbrevs, do_expand}, _)) =
+  make_consts (f (decls, constraints, rev_abbrevs, do_expand));
 
 fun abbrevs_of (Consts ({rev_abbrevs, ...}, _)) modes =
   maps (Symtab.lookup_list rev_abbrevs) modes;
@@ -75,7 +76,7 @@
 (* dest consts *)
 
 fun dest_kind (LogicalConst _) = NONE
-  | dest_kind (Abbreviation t) = SOME t;
+  | dest_kind (Abbreviation (t, t', _)) = SOME (t, t');
 
 fun dest (Consts ({decls = (space, decls), constraints, ...}, _)) =
  {constants = (space,
@@ -94,13 +95,18 @@
 fun logical_const consts c =
   (case #1 (#1 (the_const consts c)) of
     (T, LogicalConst ps) => (T, ps)
-  | _ => raise TYPE ("Illegal abbreviation: " ^ quote c, [], []));
+  | _ => raise TYPE ("Not a logical constant: " ^ quote c, [], []));
 
-val declaration = #1 oo logical_const;
+fun the_abbreviation consts c =
+  (case #1 (#1 (the_const consts c)) of
+    (T, Abbreviation (t, t', _)) => (T, (t, t'))
+  | _ => raise TYPE ("Not an abbreviated constant: " ^ quote c, [], []));
+
+val the_declaration = #1 oo logical_const;
 val type_arguments = #2 oo logical_const;
-val monomorphic = null oo type_arguments;
+val is_monomorphic = null oo type_arguments;
 
-fun constraint (consts as Consts ({constraints, ...}, _)) c =
+fun the_constraint (consts as Consts ({constraints, ...}, _)) c =
   (case Symtab.lookup constraints c of
     SOME T => T
   | NONE => #1 (#1 (#1 (the_const consts c))));
@@ -138,7 +144,7 @@
 
 (* certify *)
 
-fun certify pp tsig (consts as Consts ({expand_abbrevs, ...}, _)) =
+fun certify pp tsig (consts as Consts ({do_expand, ...}, _)) =
   let
     fun err msg (c, T) =
       raise TYPE (msg ^ " " ^ quote c ^ " :: " ^ Pretty.string_of_typ pp T, [], []);
@@ -158,10 +164,12 @@
               if not (Type.raw_instance (T', U)) then
                 err "Illegal type for constant" (c, T)
               else
-                (case (kind, expand_abbrevs) of
-                  (Abbreviation u, true) =>
-                    Term.betapplys (Envir.expand_atom T' (U, u) handle TYPE _ =>
-                      err "Illegal type for abbreviation" (c, T), args')
+                (case kind of
+                  Abbreviation (_, u, force_expand) =>
+                    if do_expand orelse force_expand then
+                      Term.betapplys (Envir.expand_atom T' (U, u) handle TYPE _ =>
+                        err "Illegal type for abbreviation" (c, T), args')
+                    else comb head
                 | _ => comb head)
             end
         | _ => comb head)
@@ -179,7 +187,7 @@
 
 fun instance consts (c, Ts) =
   let
-    val declT = declaration consts c;
+    val declT = the_declaration consts c;
     val vars = map Term.dest_TVar (typargs consts (c, declT));
   in declT |> TermSubst.instantiateT (vars ~~ Ts) end;
 
@@ -199,14 +207,14 @@
 
 (* name space *)
 
-fun hide fully c = map_consts (fn (decls, constraints, rev_abbrevs, expand_abbrevs) =>
-  (apfst (NameSpace.hide fully c) decls, constraints, rev_abbrevs, expand_abbrevs));
+fun hide fully c = map_consts (fn (decls, constraints, rev_abbrevs, do_expand) =>
+  (apfst (NameSpace.hide fully c) decls, constraints, rev_abbrevs, do_expand));
 
 
 (* declarations *)
 
 fun declare naming ((c, declT), authentic) =
-    map_consts (fn (decls, constraints, rev_abbrevs, expand_abbrevs) =>
+    map_consts (fn (decls, constraints, rev_abbrevs, do_expand) =>
   let
     fun args_of (Type (_, Ts)) pos = args_of_list Ts 0 pos
       | args_of (TVar v) pos = insert (eq_fst op =) (v, rev pos)
@@ -215,22 +223,22 @@
       | args_of_list [] _ _ = I;
     val decl =
       (((declT, LogicalConst (map #2 (rev (args_of declT [] [])))), authentic), serial ());
-  in (extend_decls naming (c, decl) decls, constraints, rev_abbrevs, expand_abbrevs) end);
+  in (extend_decls naming (c, decl) decls, constraints, rev_abbrevs, do_expand) end);
 
 
 (* constraints *)
 
 fun constrain (c, C) consts =
-  consts |> map_consts (fn (decls, constraints, rev_abbrevs, expand_abbrevs) =>
+  consts |> map_consts (fn (decls, constraints, rev_abbrevs, do_expand) =>
     (the_const consts c handle TYPE (msg, _, _) => error msg;
       (decls,
         constraints |> (case C of SOME T => Symtab.update (c, T) | NONE => Symtab.delete_safe c),
-        rev_abbrevs, expand_abbrevs)));
+        rev_abbrevs, do_expand)));
 
 
 (* abbreviations *)
 
-fun expand_abbrevs b = map_consts (fn (decls, constraints, rev_abbrevs, _) =>
+fun set_expand b = map_consts (fn (decls, constraints, rev_abbrevs, _) =>
   (decls, constraints, rev_abbrevs, b));
 
 local
@@ -243,38 +251,43 @@
       else []
   | _ => []);
 
-fun rev_abbrev const rhs =
+fun rev_abbrev lhs rhs =
   let
     fun abbrev (xs, body) =
       let val vars = fold (fn (x, T) => cons (Var ((x, 0), T))) (Term.rename_wrt_term body xs) []
-      in (Term.subst_bounds (rev vars, body), Term.list_comb (Const const, vars)) end;
+      in (Term.subst_bounds (rev vars, body), Term.list_comb (lhs, vars)) end;
   in map abbrev (strip_abss (Envir.beta_eta_contract rhs)) end;
 
 in
 
-fun abbreviate pp tsig naming mode ((c, raw_rhs), authentic) consts =
+fun abbreviate pp tsig naming mode (c, raw_rhs) consts =
   let
-    val full_c = NameSpace.full naming c;
+    val cert_term = certify pp tsig (consts |> set_expand false);
+    val expand_term = certify pp tsig (consts |> set_expand true);
+    val force_expand = (mode = #1 Syntax.internal_mode);
+
     val rhs = raw_rhs
       |> Term.map_types (Type.cert_typ tsig)
-      |> certify pp tsig (consts |> expand_abbrevs false);
-    val rhs' = rhs
-      |> certify pp tsig (consts |> expand_abbrevs true);
+      |> cert_term;
+    val rhs' = expand_term rhs;
     val T = Term.fastype_of rhs;
 
+    val const = (NameSpace.full naming c, T);
+    val lhs = Const const;
+
     fun err msg = error (msg ^ " on rhs of abbreviation:\n" ^
-      Pretty.string_of_term pp (Logic.mk_equals (Const (full_c, T), rhs)));
+      Pretty.string_of_term pp (Logic.mk_equals (lhs, rhs)));
     val _ = Term.exists_subterm Term.is_Var rhs andalso err "Illegal schematic variables"
     val _ = null (Term.hidden_polymorphism rhs T) orelse err "Extra type variables";
   in
-    consts |> map_consts (fn (decls, constraints, rev_abbrevs, expand_abbrevs) =>
+    consts |> map_consts (fn (decls, constraints, rev_abbrevs, do_expand) =>
       let
-        val decls' = decls
-          |> extend_decls naming (c, (((T, Abbreviation rhs'), authentic), serial ()));
+        val decls' = decls |> extend_decls naming
+          (c, (((T, Abbreviation (rhs, rhs', force_expand)), true), serial ()));
         val rev_abbrevs' = rev_abbrevs
-          |> fold (curry Symtab.update_list mode) (rev_abbrev (full_c, T) rhs);
-      in (decls', constraints, rev_abbrevs', expand_abbrevs) end)
-    |> pair ((full_c, T), rhs)
+          |> fold (curry Symtab.update_list mode) (rev_abbrev lhs rhs);
+      in (decls', constraints, rev_abbrevs', do_expand) end)
+    |> pair (const, rhs)
   end;
 
 end;
@@ -286,9 +299,9 @@
 
 fun merge
    (Consts ({decls = decls1, constraints = constraints1,
-      rev_abbrevs = rev_abbrevs1, expand_abbrevs = expand_abbrevs1}, _),
+      rev_abbrevs = rev_abbrevs1, do_expand = do_expand1}, _),
     Consts ({decls = decls2, constraints = constraints2,
-      rev_abbrevs = rev_abbrevs2, expand_abbrevs = expand_abbrevs2}, _)) =
+      rev_abbrevs = rev_abbrevs2, do_expand = do_expand2}, _)) =
   let
     val decls' = NameSpace.merge_tables (eq_snd (op =)) (decls1, decls2)
       handle Symtab.DUPS cs => err_dup_consts cs;
@@ -296,7 +309,7 @@
       handle Symtab.DUPS cs => err_inconsistent_constraints cs;
     val rev_abbrevs' = (rev_abbrevs1, rev_abbrevs2) |> Symtab.join
       (K (Library.merge (fn ((t, u), (t', u')) => t aconv t' andalso u aconv u')));
-    val expand_abbrevs' = expand_abbrevs1 orelse expand_abbrevs2;
-  in make_consts (decls', constraints', rev_abbrevs', expand_abbrevs') end;
+    val do_expand' = do_expand1 orelse do_expand2;
+  in make_consts (decls', constraints', rev_abbrevs', do_expand') end;
 
 end;