src/Pure/variable.ML
author wenzelm
Tue Jul 11 12:17:09 2006 +0200 (2006-07-11)
changeset 20084 aa320957f00c
parent 20003 aac2c0d29751
child 20102 6676a17dfc88
permissions -rw-r--r--
maintain Name.context for fixes/defaults;
more efficient inventing/renaming of local names (cf. name.ML);
     1 (*  Title:      Pure/variable.ML
     2     ID:         $Id$
     3     Author:     Makarius
     4 
     5 Fixed type/term variables and polymorphic term abbreviations.
     6 *)
     7 
     8 signature VARIABLE =
     9 sig
    10   val is_body: Context.proof -> bool
    11   val set_body: bool -> Context.proof -> Context.proof
    12   val restore_body: Context.proof -> Context.proof -> Context.proof
    13   val fixes_of: Context.proof -> (string * string) list
    14   val binds_of: Context.proof -> (typ * term) Vartab.table
    15   val defaults_of: Context.proof ->
    16     typ Vartab.table * sort Vartab.table * string list * string list Symtab.table * Name.context
    17   val used_types: Context.proof -> string list
    18   val is_declared: Context.proof -> string -> bool
    19   val is_fixed: Context.proof -> string -> bool
    20   val def_sort: Context.proof -> indexname -> sort option
    21   val def_type: Context.proof -> bool -> indexname -> typ option
    22   val default_type: Context.proof -> string -> typ option
    23   val declare_type: typ -> Context.proof -> Context.proof
    24   val declare_syntax: term -> Context.proof -> Context.proof
    25   val declare_term: term -> Context.proof -> Context.proof
    26   val declare_thm: thm -> Context.proof -> Context.proof
    27   val thm_context: thm -> Context.proof
    28   val rename_wrt: Context.proof -> term list -> (string * 'a) list -> (string * 'a) list
    29   val add_fixes: string list -> Context.proof -> string list * Context.proof
    30   val invent_fixes: string list -> Context.proof -> string list * Context.proof
    31   val invent_types: sort list -> Context.proof -> (string * sort) list * Context.proof
    32   val export_inst: Context.proof -> Context.proof -> string list * string list
    33   val exportT_inst: Context.proof -> Context.proof -> string list
    34   val export_terms: Context.proof -> Context.proof -> term list -> term list
    35   val exportT_terms: Context.proof -> Context.proof -> term list -> term list
    36   val exportT: Context.proof -> Context.proof -> thm list -> thm list
    37   val export: Context.proof -> Context.proof -> thm list -> thm list
    38   val importT_inst: term list -> Context.proof -> ((indexname * sort) * typ) list * Context.proof
    39   val import_inst: bool -> term list -> Context.proof ->
    40     (((indexname * sort) * typ) list * ((indexname * typ) * term) list) * Context.proof
    41   val importT_terms: term list -> Context.proof -> term list * Context.proof
    42   val import_terms: bool -> term list -> Context.proof -> term list * Context.proof
    43   val importT: thm list -> Context.proof -> thm list * Context.proof
    44   val import: bool -> thm list -> Context.proof -> thm list * Context.proof
    45   val tradeT: Context.proof -> (thm list -> thm list) -> thm list -> thm list
    46   val trade: Context.proof -> (thm list -> thm list) -> thm list -> thm list
    47   val warn_extra_tfrees: Context.proof -> Context.proof -> unit
    48   val monomorphic: Context.proof -> term list -> term list
    49   val polymorphic: Context.proof -> term list -> term list
    50   val hidden_polymorphism: term -> typ -> (indexname * sort) list
    51   val add_binds: (indexname * term option) list -> Context.proof -> Context.proof
    52   val expand_binds: Context.proof -> term -> term
    53 end;
    54 
    55 structure Variable: VARIABLE =
    56 struct
    57 
    58 
    59 (** local context data **)
    60 
    61 datatype data = Data of
    62  {is_body: bool,                                (*inner body mode*)
    63   fixes: (string * string) list * Name.context, (*term fixes -- extern/intern*)
    64   binds: (typ * term) Vartab.table,             (*term bindings*)
    65   defaults:
    66     typ Vartab.table *                          (*type constraints*)
    67     sort Vartab.table *                         (*default sorts*)
    68     string list *                               (*used type variables*)
    69     string list Symtab.table *                  (*occurrences of type variables in term variables*)
    70     Name.context};                              (*type/term variable names*)
    71 
    72 fun make_data (is_body, fixes, binds, defaults) =
    73   Data {is_body = is_body, fixes = fixes, binds = binds, defaults = defaults};
    74 
    75 structure Data = ProofDataFun
    76 (
    77   val name = "Pure/variable";
    78   type T = data;
    79   fun init thy =
    80     make_data (false, ([], Name.context), Vartab.empty,
    81       (Vartab.empty, Vartab.empty, [], Symtab.empty, Name.make_context ["", "'"]));
    82   fun print _ _ = ();
    83 );
    84 
    85 val _ = Context.add_setup Data.init;
    86 
    87 fun map_data f =
    88   Data.map (fn Data {is_body, fixes, binds, defaults} =>
    89     make_data (f (is_body, fixes, binds, defaults)));
    90 
    91 fun map_fixes f = map_data (fn (is_body, fixes, binds, defaults) =>
    92   (is_body, f fixes, binds, defaults));
    93 
    94 fun map_binds f = map_data (fn (is_body, fixes, binds, defaults) =>
    95   (is_body, fixes, f binds, defaults));
    96 
    97 fun map_defaults f = map_data (fn (is_body, fixes, binds, defaults) =>
    98   (is_body, fixes, binds, f defaults));
    99 
   100 fun rep_data ctxt = Data.get ctxt |> (fn Data args => args);
   101 
   102 val is_body = #is_body o rep_data;
   103 fun set_body b = map_data (fn (_, fixes, binds, defaults) =>
   104   (b, fixes, binds, defaults));
   105 fun restore_body ctxt = set_body (is_body ctxt);
   106 
   107 val fixes_of = #1 o #fixes o rep_data;
   108 val fixed_names_of = #2 o #fixes o rep_data;
   109 
   110 val binds_of = #binds o rep_data;
   111 
   112 val defaults_of = #defaults o rep_data;
   113 val used_types = #3 o defaults_of;
   114 val type_occs_of = #4 o defaults_of;
   115 
   116 fun is_declared ctxt x = Vartab.defined (#1 (defaults_of ctxt)) (x, ~1);
   117 fun is_fixed ctxt x = exists (fn (_, y) => x = y) (fixes_of ctxt);
   118 
   119 
   120 
   121 (** declarations **)
   122 
   123 (* default sorts and types *)
   124 
   125 val def_sort = Vartab.lookup o #2 o defaults_of;
   126 
   127 fun def_type ctxt pattern xi =
   128   let val {binds, defaults = (types, _, _, _, _), ...} = rep_data ctxt in
   129     (case Vartab.lookup types xi of
   130       NONE =>
   131         if pattern then NONE
   132         else Vartab.lookup binds xi |> Option.map (TypeInfer.polymorphicT o #1)
   133     | some => some)
   134   end;
   135 
   136 fun default_type ctxt x = Vartab.lookup (#1 (defaults_of ctxt)) (x, ~1);
   137 
   138 
   139 (* declare types/terms *)
   140 
   141 local
   142 
   143 val ins_types = fold_aterms
   144   (fn Free (x, T) => Vartab.update ((x, ~1), T)
   145     | Var v => Vartab.update v
   146     | _ => I);
   147 
   148 val ins_sorts = fold_atyps
   149   (fn TFree (x, S) => Vartab.update ((x, ~1), S)
   150     | TVar v => Vartab.update v
   151     | _ => I);
   152 
   153 val ins_used = fold_atyps
   154   (fn TFree (x, _) => insert (op =) x | _ => I);
   155 
   156 val ins_occs = fold_term_types (fn t =>
   157   let val x = case t of Free (x, _) => x | _ => ""
   158   in fold_atyps (fn TFree (a, _) => Symtab.insert_list (op =) (a, x) | _ => I) end);
   159 
   160 fun ins_skolem def_ty = fold_rev (fn (x, x') =>
   161   (case def_ty x' of
   162     SOME T => Vartab.update ((x, ~1), T)
   163   | NONE => I));
   164 
   165 val ins_namesT = fold_atyps
   166   (fn TFree (x, _) => Name.declare x | _ => I);
   167 
   168 fun ins_names t =
   169   fold_types ins_namesT t #>
   170   fold_aterms (fn Free (x, _) => Name.declare x | _ => I) t;
   171 
   172 in
   173 
   174 fun declare_type T = map_defaults (fn (types, sorts, used, occ, names) =>
   175  (types,
   176   ins_sorts T sorts,
   177   ins_used T used,
   178   occ,
   179   ins_namesT T names));
   180 
   181 fun declare_syntax t = map_defaults (fn (types, sorts, used, occ, names) =>
   182  (ins_types t types,
   183   fold_types ins_sorts t sorts,
   184   fold_types ins_used t used,
   185   occ,
   186   ins_names t names));
   187 
   188 fun declare_occs t = map_defaults (fn (types, sorts, used, occ, names) =>
   189   (types, sorts, used, ins_occs t occ, names));
   190 
   191 fun declare_term t ctxt =
   192   ctxt
   193   |> declare_syntax t
   194   |> map_defaults (fn (types, sorts, used, occ, names) =>
   195      (ins_skolem (fn x => Vartab.lookup types (x, ~1)) (fixes_of ctxt) types,
   196       sorts,
   197       used,
   198       ins_occs t occ,
   199       ins_names t names));
   200 
   201 fun declare_thm th = fold declare_term (Thm.full_prop_of th :: Thm.hyps_of th);
   202 fun thm_context th = Context.init_proof (Thm.theory_of_thm th) |> declare_thm th;
   203 
   204 end;
   205 
   206 
   207 (* renaming term/type frees *)
   208 
   209 fun rename_wrt ctxt ts frees =
   210   let
   211     val names = #5 (defaults_of (ctxt |> fold declare_syntax ts));
   212     val xs = fst (Name.variants (map #1 frees) names);
   213   in xs ~~ map snd frees end;
   214 
   215 
   216 
   217 (** fixes **)
   218 
   219 local
   220 
   221 fun no_dups [] = ()
   222   | no_dups dups = error ("Duplicate fixed variable(s): " ^ commas_quote dups);
   223 
   224 fun new_fixes xs xs' names' =
   225   map_fixes (fn (fixes, _) => (rev (xs ~~ xs') @ fixes, names')) #>
   226   fold (declare_syntax o Syntax.free) xs' #>
   227   pair xs';
   228 
   229 in
   230 
   231 fun add_fixes xs ctxt =
   232   let
   233     val _ =
   234       (case filter (can Name.dest_skolem) xs of [] => ()
   235       | bads => error ("Illegal internal Skolem constant(s): " ^ commas_quote bads));
   236     val _ = no_dups (duplicates (op =) xs);
   237     val (ys, zs) = split_list (fixes_of ctxt);
   238     val names = fixed_names_of ctxt;
   239     val (xs', names') =
   240       if is_body ctxt then Name.variants (map Name.skolem xs) names
   241       else (no_dups (xs inter_string ys); no_dups (xs inter_string zs);
   242         (xs, fold Name.declare xs names));
   243   in ctxt |> new_fixes xs xs' names' end;
   244 
   245 fun invent_fixes raw_xs ctxt =
   246   let
   247     val names = fixed_names_of ctxt;
   248     val (xs, names') = Name.variants (map Name.clean raw_xs) names;
   249     val xs' = map Name.skolem xs;
   250   in ctxt |> new_fixes xs xs' names' end;
   251 
   252 end;
   253 
   254 fun invent_types Ss ctxt =
   255   let
   256     val tfrees = Name.invents (#5 (defaults_of ctxt)) "'a" (length Ss) ~~ Ss;
   257     val ctxt' = fold (declare_type o TFree) tfrees ctxt;
   258   in (tfrees, ctxt') end;
   259 
   260 
   261 
   262 (** export -- generalize type/term variables **)
   263 
   264 fun export_inst inner outer =
   265   let
   266     val types_outer = used_types outer;
   267     val fixes_inner = fixes_of inner;
   268     val fixes_outer = fixes_of outer;
   269 
   270     val gen_fixes = map #2 (Library.take (length fixes_inner - length fixes_outer, fixes_inner));
   271     val still_fixed = not o member (op =) ("" :: gen_fixes);
   272     val gen_fixesT =
   273       Symtab.fold (fn (a, xs) =>
   274         if member (op =) types_outer a orelse exists still_fixed xs
   275         then I else cons a) (type_occs_of inner) [];
   276   in (gen_fixesT, gen_fixes) end;
   277 
   278 fun exportT_inst inner outer = #1 (export_inst inner outer);
   279 
   280 fun exportT_terms inner outer ts =
   281   map (Term.generalize (exportT_inst (fold declare_occs ts inner) outer, [])
   282     (fold (Term.fold_types Term.maxidx_typ) ts ~1 + 1)) ts;
   283 
   284 fun export_terms inner outer ts =
   285   map (Term.generalize (export_inst (fold declare_occs ts inner) outer)
   286     (fold Term.maxidx_term ts ~1 + 1)) ts;
   287 
   288 fun gen_export inst inner outer ths =
   289   let
   290     val ths' = map Thm.adjust_maxidx_thm ths;
   291     val inner' = fold (declare_occs o Thm.full_prop_of) ths' inner;
   292   in map (Thm.generalize (inst inner' outer) (fold Thm.maxidx_thm ths' ~1 + 1)) ths' end;
   293 
   294 val exportT = gen_export (rpair [] oo exportT_inst);
   295 val export = gen_export export_inst;
   296 
   297 
   298 
   299 (** import -- fix schematic type/term variables **)
   300 
   301 fun importT_inst ts ctxt =
   302   let
   303     val tvars = rev (fold Term.add_tvars ts []);
   304     val (tfrees, ctxt') = invent_types (map #2 tvars) ctxt;
   305   in (tvars ~~ map TFree tfrees, ctxt') end;
   306 
   307 fun import_inst is_open ts ctxt =
   308   let
   309     val (instT, ctxt') = importT_inst ts ctxt;
   310     val vars = map (apsnd (Term.instantiateT instT)) (rev (fold Term.add_vars ts []));
   311     val ren = if is_open then I else Name.internal;
   312     val (xs, ctxt'') = invent_fixes (map (ren o #1 o #1) vars) ctxt';
   313     val inst = vars ~~ map Free (xs ~~ map #2 vars);
   314   in ((instT, inst), ctxt'') end;
   315 
   316 fun importT_terms ts ctxt =
   317   let val (instT, ctxt') = importT_inst ts ctxt
   318   in (map (Term.instantiate (instT, [])) ts, ctxt') end;
   319 
   320 fun import_terms is_open ts ctxt =
   321   let val (inst, ctxt') = import_inst is_open ts ctxt
   322   in (map (Term.instantiate inst) ts, ctxt') end;
   323 
   324 fun importT ths ctxt =
   325   let
   326     val thy = Context.theory_of_proof ctxt;
   327     val certT = Thm.ctyp_of thy;
   328     val (instT, ctxt') = importT_inst (map Thm.full_prop_of ths) ctxt;
   329     val instT' = map (fn (v, T) => (certT (TVar v), certT T)) instT;
   330     val ths' = map (Thm.instantiate (instT', [])) ths;
   331   in (ths', ctxt') end;
   332 
   333 fun import is_open ths ctxt =
   334   let
   335     val thy = Context.theory_of_proof ctxt;
   336     val cert = Thm.cterm_of thy;
   337     val certT = Thm.ctyp_of thy;
   338     val ((instT, inst), ctxt') = import_inst is_open (map Thm.full_prop_of ths) ctxt;
   339     val instT' = map (fn (v, T) => (certT (TVar v), certT T)) instT;
   340     val inst' = map (fn (v, t) => (cert (Var v), cert t)) inst;
   341     val ths' = map (Thm.instantiate (instT', inst')) ths;
   342   in (ths', ctxt') end;
   343 
   344 
   345 (* import/export *)
   346 
   347 fun gen_trade imp exp ctxt f ths =
   348   let val (ths', ctxt') = imp ths ctxt
   349   in exp ctxt' ctxt (f ths') end;
   350 
   351 val tradeT = gen_trade importT exportT;
   352 val trade = gen_trade (import true) export;
   353 
   354 
   355 
   356 (** implicit polymorphism **)
   357 
   358 (* warn_extra_tfrees *)
   359 
   360 fun warn_extra_tfrees ctxt1 ctxt2 =
   361   let
   362     fun occs_typ a = Term.exists_subtype (fn TFree (b, _) => a = b | _ => false);
   363     fun occs_free _ "" = I
   364       | occs_free a x =
   365           (case def_type ctxt1 false (x, ~1) of
   366             SOME T => if occs_typ a T then I else cons (a, x)
   367           | NONE => cons (a, x));
   368 
   369     val occs1 = type_occs_of ctxt1 and occs2 = type_occs_of ctxt2;
   370     val extras = Symtab.fold (fn (a, xs) =>
   371       if Symtab.defined occs1 a then I else fold (occs_free a) xs) occs2 [];
   372     val tfrees = map #1 extras |> sort_distinct string_ord;
   373     val frees = map #2 extras |> sort_distinct string_ord;
   374   in
   375     if null extras then ()
   376     else warning ("Introduced fixed type variable(s): " ^ commas tfrees ^ " in " ^
   377       space_implode " or " (map quote frees))
   378   end;
   379 
   380 
   381 (* monomorphic vs. polymorphic terms *)
   382 
   383 fun monomorphic ctxt ts =
   384   #1 (importT_terms ts (fold declare_term ts ctxt));
   385 
   386 fun polymorphic ctxt ts =
   387   let
   388     val ctxt' = fold declare_term ts ctxt;
   389     val types = subtract (op =) (used_types ctxt) (used_types ctxt');
   390     val idx = fold (Term.fold_types Term.maxidx_typ) ts ~1 + 1;
   391   in map (Term.generalize (types, []) idx) ts end;
   392 
   393 
   394 
   395 (** term bindings **)
   396 
   397 fun hidden_polymorphism t T =
   398   let
   399     val tvarsT = Term.add_tvarsT T [];
   400     val extra_tvars = Term.fold_types (Term.fold_atyps
   401       (fn TVar v => if member (op =) tvarsT v then I else insert (op =) v | _ => I)) t [];
   402   in extra_tvars end;
   403 
   404 fun add_bind (xi, NONE) = map_binds (Vartab.delete_safe xi)
   405   | add_bind ((x, i), SOME t) =
   406       let
   407         val T = Term.fastype_of t;
   408         val t' =
   409           if null (hidden_polymorphism t T) then t
   410           else Var ((x ^ "_has_extra_type_vars_on_rhs", i), T);
   411       in declare_term t' #> map_binds (Vartab.update ((x, i), (T, t'))) end;
   412 
   413 val add_binds = fold add_bind;
   414 
   415 fun expand_binds ctxt =
   416   let
   417     val binds = binds_of ctxt;
   418     fun expand (t as Var (xi, T)) =
   419           (case Vartab.lookup binds xi of
   420             SOME u => Envir.expand_atom T u
   421           | NONE => t)
   422       | expand t = t;
   423   in Envir.beta_norm o Term.map_aterms expand end;
   424 
   425 end;