Changed treatment of during type inference internally generated type
authornipkow
Mon, 13 Mar 1995 09:38:10 +0100
changeset 949 83c588d6fee9
parent 948 3647161d15d3
child 950 323f8ca4587a
Changed treatment of during type inference internally generated type variables. 1. They are renamed to 'a, 'b, 'c etc away from a given set of used names. 2. They are either frozen (turned into TFrees) or left schematic (TVars) depending on a parameter. In goals they are frozen, for instantiations they are left schematic.
src/Pure/drule.ML
src/Pure/sign.ML
src/Pure/tactic.ML
src/Pure/term.ML
src/Pure/thm.ML
src/Pure/type.ML
--- a/src/Pure/drule.ML	Sat Mar 11 17:46:14 1995 +0100
+++ b/src/Pure/drule.ML	Mon Mar 13 09:38:10 1995 +0100
@@ -57,7 +57,7 @@
   val read_insts	:
           Sign.sg -> (indexname -> typ option) * (indexname -> sort option)
                   -> (indexname -> typ option) * (indexname -> sort option)
-                  -> (string*string)list
+                  -> string list -> (string*string)list
                   -> (indexname*ctyp)list * (cterm*cterm)list
   val reflexive_thm	: thm
   val revcut_rl		: thm
@@ -249,7 +249,7 @@
 fun inst_failure ixn =
   error("Instantiation of " ^ Syntax.string_of_vname ixn ^ " fails");
 
-fun read_insts sign (rtypes,rsorts) (types,sorts) insts =
+fun read_insts sign (rtypes,rsorts) (types,sorts) used insts =
 let val {tsig,...} = Sign.rep_sg sign
     fun split([],tvs,vs) = (tvs,vs)
       | split((sv,st)::l,tvs,vs) = (case explode sv of
@@ -264,14 +264,15 @@
            else inst_failure ixn
         end
     val tye = map readT tvs;
-    fun add_cterm ((cts,tye), (ixn,st)) =
+    fun add_cterm ((cts,tye,used), (ixn,st)) =
         let val T = case rtypes ixn of
                       Some T => typ_subst_TVars tye T
                     | None => absent ixn;
-            val (ct,tye2) = read_def_cterm (sign,types,sorts) (st,T);
+            val (ct,tye2) = read_def_cterm(sign,types,sorts) used false (st,T);
             val cv = cterm_of sign (Var(ixn,typ_subst_TVars tye2 T))
-        in ((cv,ct)::cts,tye2 @ tye) end
-    val (cterms,tye') = foldl add_cterm (([],tye), vs);
+            val used' = add_term_tvarnames(term_of ct,used);
+        in ((cv,ct)::cts,tye2 @ tye,used') end
+    val (cterms,tye',_) = foldl add_cterm (([],tye,used), vs);
 in (map (fn (ixn,T) => (ixn,ctyp_of sign T)) tye', cterms) end;
 
 
@@ -584,7 +585,7 @@
 (*Instantiate theorem th, reading instantiations under signature sg*)
 fun read_instantiate_sg sg sinsts th =
     let val ts = types_sorts th;
-    in  instantiate (read_insts sg ts ts sinsts) th  end;
+    in  instantiate (read_insts sg ts ts [] sinsts) th  end;
 
 (*Instantiate theorem th, reading instantiations under theory of th*)
 fun read_instantiate sinsts th =
--- a/src/Pure/sign.ML	Sat Mar 11 17:46:14 1995 +0100
+++ b/src/Pure/sign.ML	Mon Mar 13 09:38:10 1995 +0100
@@ -39,8 +39,8 @@
     val certify_term: sg -> term -> term * typ * int
     val read_typ: sg * (indexname -> sort option) -> string -> typ
     val infer_types: sg -> (indexname -> typ option) ->
-      (indexname -> sort option) -> bool -> term list * typ ->
-      int * term * (indexname * typ) list
+      (indexname -> sort option) -> string list -> bool -> bool
+      -> term list * typ -> int * term * (indexname * typ) list
     val add_classes: (class * class list) list -> sg -> sg
     val add_classrel: (class * class) list -> sg -> sg
     val add_defsort: sort -> sg -> sg
@@ -252,7 +252,7 @@
 
 (** infer_types **)         (*exception ERROR*)
 
-fun infer_types sg types sorts print_msg (ts, T) =
+fun infer_types sg types sorts used freeze print_msg (ts, T) =
   let
     val Sg {tsig, ...} = sg;
     val show_typ = string_of_typ sg;
@@ -268,16 +268,16 @@
 	cat_lines (map show_typ Ts) ^ term_err ts ^ "\n";
 
     val T' = certify_typ sg T
-      handle TYPE arg => error (exn_type_msg arg);
+             handle TYPE arg => error (exn_type_msg arg);
 
     val ct = const_type sg;
 
     fun process_terms (t::ts) (idx, infrd_t, tye) msg n =
-         let fun mk_some (x, y) = (Some x, Some y);
-
-             val ((infrd_t', tye'), msg') = 
-              (mk_some (Type.infer_types (tsig, ct, types, sorts, T', t)), msg)
-              handle TYPE arg => ((None, None), exn_type_msg arg)
+         let val (infrd_t', tye', msg') = 
+              let val (T,tye) =
+                    Type.infer_types(tsig,ct,types,sorts,used,freeze,T',t)
+              in (Some T, Some tye, msg) end
+              handle TYPE arg => (None, None, exn_type_msg arg)
 
              val old_show_brackets = !show_brackets;
 
@@ -291,8 +291,8 @@
                 (show_term (the infrd_t)) else msg') ^ "\n" ^ 
                 (show_term (the infrd_t')) ^ "\n";
 
-             val _ = (show_brackets := old_show_brackets);
-         in if is_none infrd_t' then
+         in show_brackets := old_show_brackets;
+            if is_none infrd_t' then
               process_terms ts (idx, infrd_t, tye) msg'' (n+1)
             else
               process_terms ts (Some n, infrd_t', tye') msg'' (n+1)
--- a/src/Pure/tactic.ML	Sat Mar 11 17:46:14 1995 +0100
+++ b/src/Pure/tactic.ML	Mon Mar 13 09:38:10 1995 +0100
@@ -197,9 +197,11 @@
     val rts = types_sorts rule and (types,sorts) = types_sorts state
     fun types'(a,~1) = (case assoc(params,a) of None => types(a,~1) | sm => sm)
       | types'(ixn) = types ixn;
-    val (Tinsts,insts) = read_insts sign rts (types',sorts) sinsts
+    val used = add_term_tvarnames
+                  (#prop(rep_thm state) $ #prop(rep_thm rule),[])
+    val (Tinsts,insts) = read_insts sign rts (types',sorts) used sinsts
 in instantiate (map lifttvar Tinsts, map liftpair insts)
-		(lift_rule (state,i) rule)
+               (lift_rule (state,i) rule)
 end;
 
 
--- a/src/Pure/term.ML	Sat Mar 11 17:46:14 1995 +0100
+++ b/src/Pure/term.ML	Mon Mar 13 09:38:10 1995 +0100
@@ -189,7 +189,14 @@
   | size_of_term (f$t) = size_of_term f  +  size_of_term t
   | size_of_term _ = 1;
 
- 
+fun map_type_tvar f (Type(a,Ts)) = Type(a, map (map_type_tvar f) Ts)
+  | map_type_tvar f (T as TFree _) = T
+  | map_type_tvar f (TVar x) = f x;
+
+fun map_type_tfree f (Type(a,Ts)) = Type(a, map (map_type_tfree f) Ts)
+  | map_type_tfree f (TFree x) = f x
+  | map_type_tfree f (T as TVar _) = T;
+
 (* apply a function to all types in a term *)
 fun map_term_types f =
 let fun map(Const(a,T)) = Const(a, f T)
@@ -432,9 +439,7 @@
 
 
 (* Increment the index of all Poly's in T by k *)
-fun incr_tvar k (Type(a,Ts)) = Type(a, map (incr_tvar k) Ts)
-  | incr_tvar k (T as TFree _) = T
-  | incr_tvar k (TVar((a,i),rs)) = TVar((a,i+k),rs);
+fun incr_tvar k = map_type_tvar (fn ((a,i),S) => TVar((a,i+k),S));
 
 
 (**** Syntax-related declarations ****)
@@ -500,20 +505,40 @@
   | add_typ_tfrees(TFree(f),fs) = f ins fs
   | add_typ_tfrees(TVar(_),fs) = fs;
 
+fun add_typ_varnames(Type(_,Ts),nms) = foldr add_typ_varnames (Ts,nms)
+  | add_typ_varnames(TFree(nm,_),nms) = nm ins nms
+  | add_typ_varnames(TVar((nm,_),_),nms) = nm ins nms;
+
 (*Accumulates the TVars in a term, suppressing duplicates. *)
 val add_term_tvars = it_term_types add_typ_tvars;
-val add_term_tvar_ixns = (map #1) o (it_term_types add_typ_tvars);
 
 (*Accumulates the TFrees in a term, suppressing duplicates. *)
 val add_term_tfrees = it_term_types add_typ_tfrees;
 val add_term_tfree_names = it_term_types add_typ_tfree_names;
 
+val add_term_tvarnames = it_term_types add_typ_varnames;
+
 (*Non-list versions*)
 fun typ_tfrees T = add_typ_tfrees(T,[]);
 fun typ_tvars T = add_typ_tvars(T,[]);
 fun term_tfrees t = add_term_tfrees(t,[]);
 fun term_tvars t = add_term_tvars(t,[]);
 
+(*special code to enforce left-to-right collection of TVar-indexnames*)
+
+fun add_typ_ixns(ixns,Type(_,Ts)) = foldl add_typ_ixns (ixns,Ts)
+  | add_typ_ixns(ixns,TVar(ixn,_)) = if ixn mem ixns then ixns else ixns@[ixn]
+  | add_typ_ixns(ixns,TFree(_)) = ixns;
+
+fun add_term_tvar_ixns(Const(_,T),ixns) = add_typ_ixns(ixns,T)
+  | add_term_tvar_ixns(Free(_,T),ixns) = add_typ_ixns(ixns,T)
+  | add_term_tvar_ixns(Var(_,T),ixns) = add_typ_ixns(ixns,T)
+  | add_term_tvar_ixns(Bound _,ixns) = ixns
+  | add_term_tvar_ixns(Abs(_,T,t),ixns) =
+      add_term_tvar_ixns(t,add_typ_ixns(ixns,T))
+  | add_term_tvar_ixns(f$t,ixns) =
+      add_term_tvar_ixns(t,add_term_tvar_ixns(f,ixns));
+
 (** Frees and Vars **)
 
 (*a partial ordering (not reflexive) for atomic terms*)
--- a/src/Pure/thm.ML	Sat Mar 11 17:46:14 1995 +0100
+++ b/src/Pure/thm.ML	Mon Mar 13 09:38:10 1995 +0100
@@ -106,7 +106,7 @@
   val cpure_thy		: theory
   val read_def_cterm 	:
          Sign.sg * (indexname -> typ option) * (indexname -> sort option) ->
-         string * typ -> cterm * (indexname * typ) list
+         string list -> bool -> string * typ -> cterm * (indexname * typ) list
    val reflexive	: cterm -> thm
   val rename_params_rule: string list * int -> thm -> thm
   val rep_thm		: thm -> {prop: term, hyps: term list, 
@@ -193,17 +193,18 @@
 (** read cterms **)   (*exception ERROR*)
 
 (*read term, infer types, certify term*)
-fun read_def_cterm (sign, types, sorts) (a, T) =
+fun read_def_cterm (sign, types, sorts) used freeze (a, T) =
   let
     val T' = Sign.certify_typ sign T
       handle TYPE (msg, _, _) => error msg;
     val ts = Syntax.read (#syn (Sign.rep_sg sign)) T' a;
-    val (_, t', tye) = Sign.infer_types sign types sorts true (ts, T');
+    val (_, t', tye) =
+          Sign.infer_types sign types sorts used freeze true (ts, T');
     val ct = cterm_of sign t'
       handle TERM (msg, _) => error msg;
   in (ct, tye) end;
 
-fun read_cterm sign = #1 o read_def_cterm (sign, K None, K None);
+fun read_cterm sign = #1 o read_def_cterm (sign, K None, K None) [] true;
 
 
 
@@ -371,8 +372,10 @@
     handle ERROR => err_in_axm name;
 
 fun inferT_axm sg (name, pre_tm) =
- (name, no_vars (#2 (Sign.infer_types sg (K None) (K None) true ([pre_tm], propT))))
-    handle ERROR => err_in_axm name;
+  let val t = #2(Sign.infer_types sg (K None) (K None) [] true true
+                                     ([pre_tm], propT))
+  in  (name, no_vars t) end
+  handle ERROR => err_in_axm name;
 
 
 (* extend axioms of a theory *)
@@ -758,7 +761,7 @@
 
 (* Replace all TVars by new TFrees *)
 fun freezeT(Thm{sign,maxidx,hyps,prop}) =
-  let val prop' = Type.freeze (K true) prop
+  let val prop' = Type.freeze prop
   in Thm{sign=sign, maxidx=maxidx_of_term prop', hyps=hyps, prop=prop'} end;
 
 
--- a/src/Pure/type.ML	Sat Mar 11 17:46:14 1995 +0100
+++ b/src/Pure/type.ML	Mon Mar 13 09:38:10 1995 +0100
@@ -41,10 +41,12 @@
   val rem_sorts: typ -> typ
   val cert_typ: type_sig -> typ -> typ
   val norm_typ: type_sig -> typ -> typ
-  val freeze: (indexname -> bool) -> term -> term
+  val freeze: term -> term
   val freeze_vars: typ -> typ
-  val infer_types: type_sig * (string -> typ option) * (indexname -> typ option) *
-    (indexname -> sort option) * typ * term -> term * (indexname * typ) list
+  val infer_types: type_sig * (string -> typ option) *
+                   (indexname -> typ option) * (indexname -> sort option) *
+                   string list * bool * typ * term
+                   -> term * (indexname * typ) list
   val inst_term_tvars: type_sig * (indexname * typ) list -> term -> term
   val thaw_vars: typ -> typ
   val typ_errors: type_sig -> typ * string list -> string list
@@ -72,9 +74,7 @@
   else raise_type "Illegal schematic type variable(s)" [T] [];
 
 (*turn TFrees into TVars to allow types & axioms to be written without "?"*)
-fun varifyT (Type (a, Ts)) = Type (a, map varifyT Ts)
-  | varifyT (TFree (a, S)) = TVar ((a, 0), S)
-  | varifyT T = T;
+val varifyT = map_type_tfree (fn (a, S) => TVar((a, 0), S));
 
 (*inverse of varifyT*)
 fun unvarifyT (Type (a, Ts)) = Type (a, map unvarifyT Ts)
@@ -87,13 +87,10 @@
     val fs = add_term_tfree_names (t, []) \\ fixed;
     val ixns = add_term_tvar_ixns (t, []);
     val fmap = fs ~~ variantlist (fs, map #1 ixns)
-    fun thaw (Type(a, Ts)) = Type (a, map thaw Ts)
-      | thaw (T as TVar _) = T
-      | thaw (T as TFree(a, S)) =
-          (case assoc (fmap, a) of None => T | Some b => TVar((b, 0), S))
-  in
-    map_term_types thaw t
-  end;
+    fun thaw(f as (a,S)) = case assoc (fmap, a) of
+                             None => TFree(f)
+                           | Some b => TVar((b, 0), S)
+  in  map_term_types (map_type_tfree thaw) t  end;
 
 
 
@@ -298,11 +295,10 @@
 (*Instantiation of type variables in types*)
 (*Pre: instantiations obey restrictions! *)
 fun inst_typ tye =
-  let fun inst(Type(a, Ts)) = Type(a, map inst Ts)
-        | inst(T as TFree _) = T
-        | inst(T as TVar(v, _)) =
-            (case assoc(tye, v) of Some U => inst U | None => T)
-  in inst end;
+  let fun inst(var as (v, _)) = case assoc(tye, v) of
+                                  Some U => inst_typ tye U
+                                | None => TVar(var)
+  in map_type_tvar inst end;
 
 (* 'least_sort' returns for a given type its maximum sort:
    - type variables, free types: the sort brought with
@@ -327,11 +323,10 @@
 
 (*Instantiation of type variables in types *)
 fun inst_typ_tvars(tsig, tye) =
-  let fun inst(Type(a, Ts)) = Type(a, map inst Ts)
-        | inst(T as TFree _) = T
-        | inst(T as TVar(v, S)) = (case assoc(tye, v) of
-                None => T | Some(U) => (check_has_sort(tsig, U, S); U))
-  in inst end;
+  let fun inst(var as (v, S)) = case assoc(tye, v) of
+              Some U => (check_has_sort(tsig, U, S); U)
+            | None => TVar(var)
+  in map_type_tvar inst end;
 
 (*Instantiation of type variables in terms *)
 fun inst_term_tvars(tsig, tye) = map_term_types (inst_typ_tvars(tsig, tye));
@@ -927,9 +922,8 @@
            end
   in inf end;
 
-fun freeze_vars(Type(a, Ts)) = Type(a, map freeze_vars Ts)
-  | freeze_vars(T as TFree _) = T
-  | freeze_vars(TVar(v, S)) = TFree(Syntax.string_of_vname v, S);
+val freeze_vars =
+      map_type_tvar (fn (v, S) => TFree(Syntax.string_of_vname v, S));
 
 (* Attach a type to a constant *)
 fun type_const (a, T) = Const(a, incr_tvar (new_tvar_inx()) T);
@@ -1013,25 +1007,42 @@
   | unconstrain (f$t) = unconstrain f $ unconstrain t
   | unconstrain (t) = t;
 
+fun nextname(pref,c) = if c="z" then (pref^"a", "a") else (pref,chr(ord(c)+1));
 
-(* Turn all TVars which satisfy p into new TFrees *)
-fun freeze p t =
-  let val fs = add_term_tfree_names(t, []);
-      val inxs = filter p (add_term_tvar_ixns(t, []));
-      val vmap = inxs ~~ variantlist(map #1 inxs, fs);
-      fun free(Type(a, Ts)) = Type(a, map free Ts)
-        | free(T as TVar(v, S)) =
-            (case assoc(vmap, v) of None => T | Some(a) => TFree(a, S))
-        | free(T as TFree _) = T
-  in map_term_types free t end;
+fun newtvars used =
+  let fun new([],_,vmap) = vmap
+        | new(ixn::ixns,p as (pref,c),vmap) =
+            let val nm = pref ^ c
+            in if nm mem used then new(ixn::ixns,nextname p, vmap)
+               else new(ixns, nextname p, (ixn,nm)::vmap)
+            end
+  in new end;
+
+(*
+Turn all TVars which satisfy p into new (if freeze then TFrees else TVars).
+Note that if t contains frozen TVars there is the possibility that a TVar is
+turned into one of those. This is sound but not complete.
+*)
+fun convert used freeze p t =
+  let val used = if freeze then add_term_tfree_names(t, used)
+                 else used union
+                      (map #1 (filter_out p (add_term_tvar_ixns(t, []))))
+      val ixns = filter p (add_term_tvar_ixns(t, []));
+      val vmap = newtvars used (ixns,("'","a"),[]);
+      fun conv(var as (ixn,S)) = case assoc(vmap,ixn) of
+            None => TVar(var) |
+            Some(a) => if freeze then TFree(a,S) else TVar((a,0),S);
+  in map_term_types (map_type_tvar conv) t end;
+
+fun freeze t = convert (add_term_tfree_names(t,[])) true (K true) t;
 
 (* Thaw all TVars that were frozen in freeze_vars *)
-fun thaw_vars(Type(a, Ts)) = Type(a, map thaw_vars Ts)
-  | thaw_vars(T as TFree(a, S)) = (case explode a of
+val thaw_vars =
+  let fun thaw(f as (a, S)) = (case explode a of
           "?"::"'"::vn => let val ((b, i), _) = Syntax.scan_varname vn
                           in TVar(("'"^b, i), S) end
-        | _ => T)
-  | thaw_vars(T) = T;
+        | _ => TFree f)
+  in map_type_tfree thaw end;
 
 
 fun restrict tye =
@@ -1041,8 +1052,8 @@
 
 
 (*Infer types for term t using tables. Check that t's type and T unify *)
-
-fun infer_term (tsig, const_type, types, sorts, T, t) =
+(*freeze determines if internal TVars are turned into TFrees or TVars*)
+fun infer_term (tsig, const_type, types, sorts, used, freeze, T, t) =
   let
     val u = attach_types (tsig, const_type, types, sorts) t;
     val (U, tye) = infer tsig ([], u, []);
@@ -1053,8 +1064,8 @@
     val all = Const("", Type("", map snd Ttye)) $ (inst_types tye' uu)
       (*all is a dummy term which contains all exported TVars*)
     val Const(_, Type(_, Ts)) $ u'' =
-      map_term_types thaw_vars (freeze (fn (_, i) => i < 0) all)
-      (*turn all internally generated TVars into TFrees
+      map_term_types thaw_vars (convert used freeze (fn (_, i) => i < 0) all)
+      (*convert all internally generated TVars into TFrees or TVars
         and thaw all initially frozen TVars*)
   in
     (u'', (map fst Ttye) ~~ Ts)
@@ -1064,4 +1075,3 @@
 
 
 end;
-