Fixed type/term variables and polymorphic term abbreviations.
authorwenzelm
Thu Jun 15 23:08:58 2006 +0200 (2006-06-15)
changeset 19899b7385ca02d79
parent 19898 b1d179e42713
child 19900 21a99d88d925
Fixed type/term variables and polymorphic term abbreviations.
src/Pure/variable.ML
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/Pure/variable.ML	Thu Jun 15 23:08:58 2006 +0200
     1.3 @@ -0,0 +1,357 @@
     1.4 +(*  Title:      Pure/variable.ML
     1.5 +    ID:         $Id$
     1.6 +    Author:     Makarius
     1.7 +
     1.8 +Fixed type/term variables and polymorphic term abbreviations.
     1.9 +*)
    1.10 +
    1.11 +signature VARIABLE =
    1.12 +sig
    1.13 +  val is_body: Context.proof -> bool
    1.14 +  val set_body: bool -> Context.proof -> Context.proof
    1.15 +  val restore_body: Context.proof -> Context.proof -> Context.proof
    1.16 +  val fixes_of: Context.proof -> (string * string) list
    1.17 +  val fixed_names_of: Context.proof -> string list
    1.18 +  val binds_of: Context.proof -> (typ * term) Vartab.table
    1.19 +  val defaults_of: Context.proof ->
    1.20 +    typ Vartab.table * sort Vartab.table * string list * term list Symtab.table
    1.21 +  val used_types: Context.proof -> string list
    1.22 +  val is_declared: Context.proof -> string -> bool
    1.23 +  val is_fixed: Context.proof -> string -> bool
    1.24 +  val def_sort: Context.proof -> indexname -> sort option
    1.25 +  val def_type: Context.proof -> bool -> indexname -> typ option
    1.26 +  val default_type: Context.proof -> string -> typ option
    1.27 +  val declare_type: typ -> Context.proof -> Context.proof
    1.28 +  val declare_syntax: term -> Context.proof -> Context.proof
    1.29 +  val declare_term: term -> Context.proof -> Context.proof
    1.30 +  val invent_types: sort list -> Context.proof -> (string * sort) list * Context.proof
    1.31 +  val rename_wrt: Context.proof -> term list -> (string * 'a) list -> (string * 'a) list
    1.32 +  val warn_extra_tfrees: Context.proof -> Context.proof -> unit
    1.33 +  val generalize_tfrees: Context.proof -> Context.proof -> string list -> string list
    1.34 +  val generalize: Context.proof -> Context.proof -> term list -> term list
    1.35 +  val polymorphic: Context.proof -> term list -> term list
    1.36 +  val hidden_polymorphism: term -> typ -> (indexname * sort) list
    1.37 +  val monomorphic_inst: term list -> Context.proof ->
    1.38 +    ((indexname * sort) * typ) list * Context.proof
    1.39 +  val monomorphic: Context.proof -> term list -> term list
    1.40 +  val add_binds: (indexname * term option) list -> Context.proof -> Context.proof
    1.41 +  val expand_binds: Context.proof -> term -> term
    1.42 +  val add_fixes: string list -> Context.proof -> string list * Context.proof
    1.43 +  val invent_fixes: string list -> Context.proof -> string list * Context.proof
    1.44 +  val import_types: bool -> typ list -> Context.proof -> typ list * Context.proof
    1.45 +  val import_terms: bool -> term list -> Context.proof -> term list * Context.proof
    1.46 +  val import: bool -> thm list -> Context.proof -> thm list * Context.proof
    1.47 +end;
    1.48 +
    1.49 +structure Variable: VARIABLE =
    1.50 +struct
    1.51 +
    1.52 +(** local context data **)
    1.53 +
    1.54 +datatype data = Data of
    1.55 + {is_body: bool,                          (*internal body mode*)
    1.56 +  fixes: (string * string) list,          (*term fixes*)
    1.57 +  binds: (typ * term) Vartab.table,       (*term bindings*)
    1.58 +  defaults:
    1.59 +    typ Vartab.table *                    (*type constraints*)
    1.60 +    sort Vartab.table *                   (*default sorts*)
    1.61 +    string list *                         (*used type variables*)
    1.62 +    term list Symtab.table};              (*type variable occurrences*)
    1.63 +
    1.64 +fun make_data (is_body, fixes, binds, defaults) =
    1.65 +  Data {is_body = is_body, fixes = fixes, binds = binds, defaults = defaults};
    1.66 +
    1.67 +structure Data = ProofDataFun
    1.68 +(
    1.69 +  val name = "Pure/variable";
    1.70 +  type T = data;
    1.71 +  fun init thy =
    1.72 +    make_data (false, [], Vartab.empty, (Vartab.empty, Vartab.empty, [], Symtab.empty));
    1.73 +  fun print _ _ = ();
    1.74 +);
    1.75 +
    1.76 +val _ = Context.add_setup Data.init;
    1.77 +
    1.78 +fun map_data f =
    1.79 +  Data.map (fn Data {is_body, fixes, binds, defaults} =>
    1.80 +    make_data (f (is_body, fixes, binds, defaults)));
    1.81 +
    1.82 +fun map_fixes f = map_data (fn (is_body, fixes, binds, defaults) =>
    1.83 +  (is_body, f fixes, binds, defaults));
    1.84 +
    1.85 +fun map_binds f = map_data (fn (is_body, fixes, binds, defaults) =>
    1.86 +  (is_body, fixes, f binds, defaults));
    1.87 +
    1.88 +fun map_defaults f = map_data (fn (is_body, fixes, binds, defaults) =>
    1.89 +  (is_body, fixes, binds, f defaults));
    1.90 +
    1.91 +fun rep_data ctxt = Data.get ctxt |> (fn Data args => args);
    1.92 +
    1.93 +val is_body = #is_body o rep_data;
    1.94 +fun set_body b = map_data (fn (_, fixes, binds, defaults) => (b, fixes, binds, defaults));
    1.95 +fun restore_body ctxt = set_body (is_body ctxt);
    1.96 +
    1.97 +val fixes_of = #fixes o rep_data;
    1.98 +val fixed_names_of = map #2 o fixes_of;
    1.99 +
   1.100 +val binds_of = #binds o rep_data;
   1.101 +
   1.102 +val defaults_of = #defaults o rep_data;
   1.103 +val used_types = #3 o defaults_of;
   1.104 +val type_occs_of = #4 o defaults_of;
   1.105 +
   1.106 +fun is_declared ctxt x = Vartab.defined (#1 (defaults_of ctxt)) (x, ~1);
   1.107 +fun is_fixed ctxt x = exists (fn (_, y) => x = y) (fixes_of ctxt);
   1.108 +
   1.109 +
   1.110 +
   1.111 +(** declarations **)
   1.112 +
   1.113 +(* default sorts and types *)
   1.114 +
   1.115 +val def_sort = Vartab.lookup o #2 o defaults_of;
   1.116 +
   1.117 +fun def_type ctxt pattern xi =
   1.118 +  let val {binds, defaults = (types, _, _, _), ...} = rep_data ctxt in
   1.119 +    (case Vartab.lookup types xi of
   1.120 +      NONE =>
   1.121 +        if pattern then NONE
   1.122 +        else Vartab.lookup binds xi |> Option.map (TypeInfer.polymorphicT o #1)
   1.123 +    | some => some)
   1.124 +  end;
   1.125 +
   1.126 +fun default_type ctxt x = Vartab.lookup (#1 (defaults_of ctxt)) (x, ~1);
   1.127 +
   1.128 +
   1.129 +(* declare types/terms *)
   1.130 +
   1.131 +local
   1.132 +
   1.133 +val ins_types = fold_aterms
   1.134 +  (fn Free (x, T) => Vartab.update ((x, ~1), T)
   1.135 +    | Var v => Vartab.update v
   1.136 +    | _ => I);
   1.137 +
   1.138 +val ins_sorts = fold_atyps
   1.139 +  (fn TFree (x, S) => Vartab.update ((x, ~1), S)
   1.140 +    | TVar v => Vartab.update v
   1.141 +    | _ => I);
   1.142 +
   1.143 +val ins_used = fold_atyps
   1.144 +  (fn TFree (x, _) => insert (op =) x | _ => I);
   1.145 +
   1.146 +val ins_occs = fold_term_types (fn t =>
   1.147 +  fold_atyps (fn TFree (x, _) => Symtab.update_list (x, t) | _ => I));
   1.148 +
   1.149 +fun ins_skolem def_ty = fold_rev (fn (x, x') =>
   1.150 +  (case def_ty x' of
   1.151 +    SOME T => Vartab.update ((x, ~1), T)
   1.152 +  | NONE => I));
   1.153 +
   1.154 +in
   1.155 +
   1.156 +fun declare_type T = map_defaults (fn (types, sorts, used, occ) =>
   1.157 + (types,
   1.158 +  ins_sorts T sorts,
   1.159 +  ins_used T used,
   1.160 +  occ));
   1.161 +
   1.162 +fun declare_syntax t = map_defaults (fn (types, sorts, used, occ) =>
   1.163 + (ins_types t types,
   1.164 +  fold_types ins_sorts t sorts,
   1.165 +  fold_types ins_used t used,
   1.166 +  occ));
   1.167 +
   1.168 +fun declare_term t ctxt =
   1.169 +  ctxt
   1.170 +  |> declare_syntax t
   1.171 +  |> map_defaults (fn (types, sorts, used, occ) =>
   1.172 +     (ins_skolem (fn x => Vartab.lookup types (x, ~1)) (fixes_of ctxt) types,
   1.173 +      sorts,
   1.174 +      used,
   1.175 +      ins_occs t occ));
   1.176 +
   1.177 +end;
   1.178 +
   1.179 +
   1.180 +(* invent types *)
   1.181 +
   1.182 +fun invent_types Ss ctxt =
   1.183 +  let
   1.184 +    val tfrees = Term.invent_names (used_types ctxt) "'a" (length Ss) ~~ Ss;
   1.185 +    val ctxt' = fold (declare_type o TFree) tfrees ctxt;
   1.186 +  in (tfrees, ctxt') end;
   1.187 +
   1.188 +
   1.189 +(* renaming term/type frees *)
   1.190 +
   1.191 +fun rename_wrt ctxt ts frees =
   1.192 +  let
   1.193 +    val (types, sorts, _, _) = defaults_of (ctxt |> fold declare_syntax ts);
   1.194 +    fun ren (x, X) xs =
   1.195 +      let
   1.196 +        fun used y = y = "" orelse y = "'" orelse member (op =) xs y orelse
   1.197 +          Vartab.defined types (y, ~1) orelse Vartab.defined sorts (y, ~1);
   1.198 +        val x' = Term.variant_name used x;
   1.199 +      in ((x', X), x' :: xs) end;
   1.200 +  in #1 (fold_map ren frees []) end;
   1.201 +
   1.202 +
   1.203 +
   1.204 +(** Hindley-Milner polymorphism **)
   1.205 +
   1.206 +(* warn_extra_tfrees *)
   1.207 +
   1.208 +fun warn_extra_tfrees ctxt1 ctxt2 =
   1.209 +  let
   1.210 +    fun occs_typ a (Type (_, Ts)) = exists (occs_typ a) Ts
   1.211 +      | occs_typ a (TFree (b, _)) = a = b
   1.212 +      | occs_typ _ (TVar _) = false;
   1.213 +    fun occs_free a (Free (x, _)) =
   1.214 +          (case def_type ctxt1 false (x, ~1) of
   1.215 +            SOME T => if occs_typ a T then I else cons (a, x)
   1.216 +          | NONE => cons (a, x))
   1.217 +      | occs_free _ _ = I;
   1.218 +
   1.219 +    val occs1 = type_occs_of ctxt1 and occs2 = type_occs_of ctxt2;
   1.220 +    val extras = Symtab.fold (fn (a, ts) =>
   1.221 +      if Symtab.defined occs1 a then I else fold (occs_free a) ts) occs2 [];
   1.222 +    val tfrees = map #1 extras |> sort_distinct string_ord;
   1.223 +    val frees = map #2 extras |> sort_distinct string_ord;
   1.224 +  in
   1.225 +    if null extras then ()
   1.226 +    else warning ("Introduced fixed type variable(s): " ^ commas tfrees ^ " in " ^
   1.227 +      space_implode " or " (map quote frees))
   1.228 +  end;
   1.229 +
   1.230 +
   1.231 +(* generalize type variables *)
   1.232 +
   1.233 +fun generalize_tfrees inner outer =
   1.234 +  let
   1.235 +    val extra_fixes = subtract (op =) (fixed_names_of outer) (fixed_names_of inner);
   1.236 +    fun still_fixed (Free (x, _)) = not (member (op =) extra_fixes x)
   1.237 +      | still_fixed _ = false;
   1.238 +    val occs_inner = type_occs_of inner;
   1.239 +    val occs_outer = type_occs_of outer;
   1.240 +    fun add a gen =
   1.241 +      if Symtab.defined occs_outer a orelse
   1.242 +        exists still_fixed (Symtab.lookup_list occs_inner a)
   1.243 +      then gen else a :: gen;
   1.244 +  in fn tfrees => fold add tfrees [] end;
   1.245 +
   1.246 +fun generalize inner outer ts =
   1.247 +  let
   1.248 +    val tfrees = generalize_tfrees inner outer (map #1 (fold Term.add_tfrees ts []));
   1.249 +    fun gen (x, S) = if member (op =) tfrees x then TVar ((x, 0), S) else TFree (x, S);
   1.250 +  in map (Term.map_term_types (Term.map_type_tfree gen)) ts end;
   1.251 +
   1.252 +fun polymorphic ctxt ts =
   1.253 +  generalize (fold declare_term ts ctxt) ctxt ts;
   1.254 +
   1.255 +fun hidden_polymorphism t T =
   1.256 +  let
   1.257 +    val tvarsT = Term.add_tvarsT T [];
   1.258 +    val extra_tvars = Term.fold_types (Term.fold_atyps
   1.259 +      (fn TVar v => if member (op =) tvarsT v then I else insert (op =) v | _ => I)) t [];
   1.260 +  in extra_tvars end;
   1.261 +
   1.262 +
   1.263 +(* monomorphic -- fixes type variables *)
   1.264 +
   1.265 +fun monomorphic_inst ts ctxt =
   1.266 +  let
   1.267 +    val tvars = rev (fold Term.add_tvars ts []);
   1.268 +    val (tfrees, ctxt') = invent_types (map #2 tvars) ctxt;
   1.269 +  in (tvars ~~ map TFree tfrees, ctxt') end;
   1.270 +
   1.271 +fun monomorphic ctxt ts =
   1.272 +  map (Term.instantiate (#1 (monomorphic_inst ts (fold declare_term ts ctxt)), [])) ts;
   1.273 +
   1.274 +
   1.275 +
   1.276 +(** term abbreviations **)
   1.277 +
   1.278 +fun add_bind (xi, NONE) = map_binds (Vartab.delete_safe xi)
   1.279 +  | add_bind ((x, i), SOME t) =
   1.280 +      let
   1.281 +        val T = Term.fastype_of t;
   1.282 +        val t' =
   1.283 +          if null (hidden_polymorphism t T) then t
   1.284 +          else Var ((x ^ "_has_extra_type_vars_on_rhs", i), T);
   1.285 +      in declare_term t' #> map_binds (Vartab.update ((x, i), (T, t'))) end;
   1.286 +
   1.287 +val add_binds = fold add_bind;
   1.288 +
   1.289 +fun expand_binds ctxt =
   1.290 +  let
   1.291 +    val binds = binds_of ctxt;
   1.292 +    fun expand (t as Var (xi, T)) =
   1.293 +          (case Vartab.lookup binds xi of
   1.294 +            SOME u => Envir.expand_atom T u
   1.295 +          | NONE => t)
   1.296 +      | expand t = t;
   1.297 +  in Envir.beta_norm o Term.map_aterms expand end;
   1.298 +
   1.299 +
   1.300 +
   1.301 +(** fixes **)
   1.302 +
   1.303 +fun no_dups [] = ()
   1.304 +  | no_dups dups = error ("Duplicate fixed variable(s): " ^ commas_quote dups);
   1.305 +
   1.306 +fun add_fixes xs ctxt =
   1.307 +  let
   1.308 +    val (ys, zs) = split_list (fixes_of ctxt);
   1.309 +    val _ = no_dups (duplicates (op =) xs);
   1.310 +    val _ =
   1.311 +      (case filter (can Syntax.dest_skolem) xs of [] => ()
   1.312 +      | bads => error ("Illegal internal Skolem constant(s): " ^ commas_quote bads));
   1.313 +    val xs' =
   1.314 +      if is_body ctxt then Term.variantlist (map Syntax.skolem xs, zs)
   1.315 +      else (no_dups (xs inter_string ys); no_dups (xs inter_string zs); xs);
   1.316 +  in
   1.317 +    ctxt
   1.318 +    |> map_fixes (fn fixes => rev (xs ~~ xs') @ fixes)
   1.319 +    |> fold (declare_syntax o Syntax.free) xs'
   1.320 +    |> pair xs'
   1.321 +  end;
   1.322 +
   1.323 +fun invent_fixes xs ctxt =
   1.324 +  ctxt
   1.325 +  |> set_body true
   1.326 +  |> add_fixes (Term.variantlist (xs, []))
   1.327 +  ||> restore_body ctxt;
   1.328 +
   1.329 +
   1.330 +(* import -- fixes schematic variables *)
   1.331 +
   1.332 +fun import_inst is_open ts ctxt =
   1.333 +  let
   1.334 +    val (instT, ctxt') = monomorphic_inst ts ctxt;
   1.335 +    val vars = map (apsnd (Term.instantiateT instT)) (rev (fold Term.add_vars ts []));
   1.336 +    val ren = if is_open then I else Syntax.internal;
   1.337 +    val (xs, ctxt'') = invent_fixes (map (ren o #1 o #1) vars) ctxt';
   1.338 +    val inst = vars ~~ map Free (xs ~~ map #2 vars);
   1.339 +  in ((instT, inst), ctxt'') end;
   1.340 +
   1.341 +fun import_terms is_open ts ctxt =
   1.342 +  let val (inst, ctxt') = import_inst is_open ts ctxt
   1.343 +  in (map (Term.instantiate inst) ts, ctxt') end;
   1.344 +
   1.345 +fun import_types is_open Ts ctxt =
   1.346 +  import_terms is_open (map Logic.mk_type Ts) ctxt
   1.347 +  |>> map Logic.dest_type;
   1.348 +
   1.349 +fun import is_open ths ctxt =
   1.350 +  let
   1.351 +    val thy = Context.theory_of_proof ctxt;
   1.352 +    val cert = Thm.cterm_of thy;
   1.353 +    val certT = Thm.ctyp_of thy;
   1.354 +    val ((instT, inst), ctxt') = import_inst is_open (map Thm.full_prop_of ths) ctxt;
   1.355 +    val instT' = map (fn (v, T) => (certT (TVar v), certT T)) instT;
   1.356 +    val inst' = map (fn (v, t) => (cert (Var v), cert t)) inst;
   1.357 +    val ths' = map (Thm.instantiate (instT', inst')) ths;
   1.358 +  in (ths', ctxt') end;
   1.359 +
   1.360 +end;