Modified pattern.ML to perform proper matching of Higher-Order Patterns.
authornipkow
Wed, 02 Nov 1994 09:09:30 +0100
changeset 678 6151b7f3b606
parent 677 dbb8431184f9
child 679 a682bbf70dc6
Modified pattern.ML to perform proper matching of Higher-Order Patterns. Modified thm.ML to preserve bound var names during rewriting. Renamed eta_matches to matches.
src/Pure/goals.ML
src/Pure/pattern.ML
src/Pure/thm.ML
--- a/src/Pure/goals.ML	Tue Nov 01 10:40:10 1994 +0100
+++ b/src/Pure/goals.ML	Wed Nov 02 09:09:30 1994 +0100
@@ -151,13 +151,13 @@
             else if not (null hyps) then result_error state
                 ("Additional hypotheses:\n" ^ 
                  cat_lines (map (Sign.string_of_term sign) hyps))
-	    else if Pattern.eta_matches (#tsig(Sign.rep_sg sign)) 
-			                (term_of chorn, prop)
+	    else if Pattern.matches (#tsig(Sign.rep_sg sign)) 
+			            (term_of chorn, prop)
 		 then  standard th 
 	    else  result_error state "proved a different theorem"
         end
-  in  
-     if Sign.eq_sg(sign, #sign(rep_thm st0)) 
+  in
+     if Sign.eq_sg(sign, #sign(rep_thm st0))
      then (prems, st0, mkresult)
      else error ("Definitions would change the proof state's signature" ^
 		 sign_error (sign, #sign(rep_thm st0)))
--- a/src/Pure/pattern.ML	Tue Nov 01 10:40:10 1994 +0100
+++ b/src/Pure/pattern.ML	Wed Nov 02 09:09:30 1994 +0100
@@ -8,6 +8,8 @@
 See also:
 Tobias Nipkow. Functional Unification of Higher-Order Patterns.
 In Proceedings of the 8th IEEE Symposium Logic in Computer Science, 1993.
+
+TODO: optimize red by special-casing it
 *)
 
 signature PATTERN =
@@ -18,7 +20,7 @@
   val eta_contract: term -> term
   val match: type_sig -> term * term
         -> (indexname*typ)list * (indexname*term)list
-  val eta_matches: type_sig -> term * term -> bool
+  val matches: type_sig -> term * term -> bool
   val unify: sg * env * (term * term)list -> env
   exception Unif
   exception MATCH
@@ -46,17 +48,12 @@
           | occ _           = false
     in occ t end;
 
-(* Something's wrong *)
-fun ill_formed s = error ("Ill-formed argument in "^s);
-
 
 fun mapbnd f =
     let fun mpb d (Bound(i))     = if i < d then Bound(i) else Bound(f(i-d)+d)
-          | mpb d (Free(c,T))    = Free(c,T)
-          | mpb d (Const(c,T))   = Const(c,T)
-          | mpb d (Var(iname,T)) = Var(iname,T)
           | mpb d (Abs(s,T,t))   = Abs(s,T,mpb(d+1) t)
-          | mpb d ((u1 $ u2))    = mpb d (u1)$ mpb d (u2)
+          | mpb d ((u1 $ u2))    = (mpb d u1)$(mpb d u2)
+          | mpb _ atom           = atom
     in mpb 0 end;
 
 fun idx [] j     = ~10000
@@ -74,7 +71,6 @@
 
 val incr = mapbnd (fn i => i+1);
 
-(* termlist --> intlist *)
 fun ints_of []             = []
   | ints_of (Bound i ::bs) = 
       let val is = ints_of bs
@@ -86,12 +82,14 @@
   | app (s,[])      = s;
 
 fun red (Abs(_,_,s)) (i::is) js = red s is (i::js)
-  | red s            is      jn = app (mapbnd (at jn) s,is);
+  | red t            []      [] = t
+  | red t            is      jn = app (mapbnd (at jn) t,is);
+
 
 (* split_type ([T1,....,Tn]---> T,n,[]) = ([Tn,...,T1],T) *)
 fun split_type (T,0,Ts)                    = (Ts,T)
   | split_type (Type ("fun",[T1,T2]),n,Ts) = split_type (T2,n-1,T1::Ts)
-  | split_type _                           = ill_formed("split_type");
+  | split_type _                           = error("split_type");
 
 fun type_of_G (T,n,is) =
   let val (Ts,U) = split_type(T,n,[]) in map(at Ts)is ---> U end;
@@ -161,7 +159,7 @@
     let fun mk([],[],_)        = [] 
           | mk(i::is,j::js, k) = if i=j then k :: mk(is,js,k-1)
                                         else mk(is,js,k-1)
-          | mk _               = ill_formed"mk_ff_list"
+          | mk _               = error"mk_ff_list"
     in mk(is,js,length is-1) end;
 
 fun flexflex1(env,binders,F,Fty,is,js) =
@@ -238,71 +236,8 @@
   | eta_contract t = t;
 
 
-(* Pattern matching. Raises MATCH if non-pattern *)
+(* Pattern matching *)
 exception MATCH;
-(* something wron with types, esp in abstractions
-fun typ_match args = Type.typ_match (!tsgr) args
-                     handle Type.TYPE_MATCH => raise MATCH;
-
-fun match_bind(itms,binders,ixn,is,t) =
-  let val js = loose_bnos t
-  in if null is
-     then if null js then (ixn,t)::itms else raise MATCH
-     else if js subset is
-          then let val t' = if downto0(is,length binders - 1) then t
-                            else mapbnd (idx is) t
-               in (ixn, eta_contract(mkabs(binders,is,t'))) :: itms end
-          else raise MATCH
-  end;
-
-fun match_rr (iTs,(a,Ta),(b,Tb)) =
-      if a<>b then raise MATCH else typ_match (iTs,(Ta,Tb))
-
-(* Pre: pat and obj have same type *)
-fun mtch(binders,env as (iTs,itms),pat,obj) = case pat of
-      Var(ixn,_) => (case assoc(itms,ixn) of
-                       None => (iTs,match_bind(itms,binders,ixn,[],obj))
-                     | Some u => if obj aconv u then env else raise MATCH)
-    | Abs(ns,Ts,ts) =>
-        (case obj of
-           Abs(nt,Tt,tt) => mtch((nt,Tt)::binders,env,ts,tt)
-         | _ => let val Tt = typ_subst_TVars iTs Ts
-                in  mtch((ns,Tt)::binders,env,ts,(incr obj)$Bound(0)) end)
-    | _ => (case obj of
-              Abs(nt,Tt,tt) =>
-                mtch((nt,Tt)::binders,env,(incr pat)$Bound(0),tt)
-            | _ => cases(binders,env,pat,obj))
-
-and cases(binders,env as (iTs,itms),pat,obj) =
-  let fun structural() = case (pat,obj) of
-            (Const c,Const d) => (match_rr(iTs,c,d),itms)
-          | (Free f,Free g)   => (match_rr(iTs,f,g),itms)
-          | (Bound i,Bound j) => if i=j then env else raise MATCH
-          | (f$t,g$u)         => mtch(binders,mtch(binders,env,t,u),f,g)
-          | _                 => raise MATCH
-  in case strip_comb pat of
-       (Var(ixn,_),bs) =>
-         (let val is = ints_of bs
-          in case assoc(itms,ixn) of
-               None => (iTs,match_bind(itms,binders,ixn,is,obj))
-             | Some u => if obj aconv (red u is []) then env else raise MATCH
-          end (* if ints_of fails: *) handle Pattern => structural())
-     | _ => structural()
-  end;
-
-fun match tsg = (tsgr := tsg;
-                 fn (pat,obj) => 
-                   let val pT = fastype_of pat
-                       and oT = fastype_of obj
-                       val iTs = typ_match ([],(pT,oT))
-                   in mtch([], (iTs,[]), pat, eta_contract obj)
-                      handle Pattern => raise MATCH
-                   end)
-
-(*Predicate: does the pattern match the object?*)
-fun matches tsig args = (match tsig args; true)
-                        handle MATCH => false;
-*)
 
 (*First-order matching;  term_match tsig (pattern, object)
     returns a (tyvar,typ)list and (var,term)list.
@@ -310,7 +245,7 @@
   Instantiation does not affect the object, so matching ?a with ?a+1 works.
   A Const does not match a Free of the same name! 
   Does not notice eta-equality, thus f does not match %(x)f(x)  *)
-fun match tsig (pat,obj) =
+fun fomatch tsig (pat,obj) =
   let fun typ_match args = (Type.typ_match tsig args)
 			   handle Type.TYPE_MATCH => raise MATCH;
       fun mtch (tyinsts,insts) = fn
@@ -331,11 +266,90 @@
 	  mtch (typ_match (tyinsts,(T,U)),insts) (t,u)
       | (f$t, g$u) => mtch (mtch (tyinsts,insts) (f,g)) (t, u)
       | _ => raise MATCH
-  in mtch([],[]) (pat,obj) end;
+  in mtch([],[]) (eta_contract pat,eta_contract obj) end;
+
+
+fun match_bind(itms,binders,ixn,is,t) =
+  let val js = loose_bnos t
+  in if null is
+     then if null js then (ixn,t)::itms else raise MATCH
+     else if js subset is
+          then let val t' = if downto0(is,length binders - 1) then t
+                            else mapbnd (idx is) t
+               in (ixn, mkabs(binders,is,t')) :: itms end
+          else raise MATCH
+  end;
+
+(*Tests whether 2 terms are alpha/eta-convertible and have same type.
+  Note that Consts and Vars may have more than one type.*)
+infix aeconv;
+fun (Const(a,T)) aeconv (Const(b,U)) = a=b  andalso  T=U
+  | (Free(a,T))  aeconv (Free(b,U))  = a=b  andalso  T=U
+  | (Var(v,T))   aeconv (Var(w,U))   = v=w  andalso  T=U
+  | (Bound i)    aeconv (Bound j)    = i=j
+  | (Abs(_,T,t)) aeconv (Abs(_,U,u)) = (t aeconv u)  andalso  T=U
+  | (Abs(_,T,t)) aeconv u            = t aeconv ((incr u)$Bound(0))
+  | t            aeconv (Abs(_,U,u)) = ((incr t)$Bound(0)) aeconv u
+  | (f$t)        aeconv (g$u)        = (f aeconv g)  andalso (t aeconv u)
+  | _ aeconv _                       =  false;
+
+
+fun match tsg (po as (pat,obj)) =
+let
+
+fun typ_match args = Type.typ_match tsg args
+                     handle Type.TYPE_MATCH => raise MATCH;
+
+(* Pre: pat and obj have same type *)
+fun mtch binders (env as (iTs,itms),(pat,obj)) =
+      case pat of
+        Abs(ns,Ts,ts) =>
+          (case obj of
+             Abs(nt,Tt,tt) => mtch ((nt,Tt)::binders) (env,(ts,tt))
+           | _ => let val Tt = typ_subst_TVars iTs Ts
+                  in mtch((ns,Tt)::binders)(env,(ts,(incr obj)$Bound(0))) end)
+      | _ => (case obj of
+                Abs(nt,Tt,tt) =>
+                  mtch((nt,Tt)::binders)(env,((incr pat)$Bound(0),tt))
+              | _ => cases(binders,env,pat,obj))
+
+and cases(binders,env as (iTs,itms),pat,obj) =
+      let val (ph,pargs) = strip_comb pat
+          fun rigrig1(iTs,oargs) =
+                foldl (mtch binders) ((iTs,itms), pargs~~oargs)
+          fun rigrig2((a,Ta),(b,Tb),oargs) =
+                if a<> b then raise MATCH
+                else rigrig1(typ_match(iTs,(Ta,Tb)), oargs)
+      in case ph of
+           Var(ixn,_) =>
+             let val is = ints_of pargs
+             in case assoc(itms,ixn) of
+                  None => (iTs,match_bind(itms,binders,ixn,is,obj))
+                | Some u => if obj aeconv (red u is []) then env
+                            else raise MATCH
+             end
+         | _ =>
+             let val (oh,oargs) = strip_comb obj
+             in case (ph,oh) of
+                  (Const c,Const d) => rigrig2(c,d,oargs)
+                | (Free f,Free g)   => rigrig2(f,g,oargs)
+                | (Bound i,Bound j) => if i<>j then raise MATCH
+                                       else rigrig1(iTs,oargs)
+                | (Abs _, _)        => raise Pattern
+                | (_, Abs _)        => raise Pattern
+                | _                 => raise MATCH
+             end
+      end;
+
+val pT = fastype_of pat
+and oT = fastype_of obj
+val iTs = typ_match ([],(pT,oT))
+
+in mtch [] ((iTs,[]), po)
+   handle Pattern => fomatch tsg po
+end;
 
 (*Predicate: does the pattern match the object?*)
-fun eta_matches tsig (pat,obj) =
-      (match tsig (eta_contract pat,eta_contract obj); true)
-      handle MATCH => false;
+fun matches tsig po = (match tsig po; true) handle MATCH => false;
 
 end;
--- a/src/Pure/thm.ML	Tue Nov 01 10:40:10 1994 +0100
+++ b/src/Pure/thm.ML	Wed Nov 02 09:09:30 1994 +0100
@@ -1031,15 +1031,18 @@
 fun loops sign prems (lhs,rhs) =
   is_Var(lhs) orelse
   (null(prems) andalso
-   Pattern.eta_matches (#tsig(Sign.rep_sg sign)) (lhs,rhs));
+   Pattern.matches (#tsig(Sign.rep_sg sign)) (lhs,rhs));
 
 fun mk_rrule (thm as Thm{hyps,sign,prop,maxidx,...}) =
   let val prems = Logic.strip_imp_prems prop
-      val concl = Pattern.eta_contract (Logic.strip_imp_concl prop)
-      val (lhs,rhs) = Logic.dest_equals concl handle TERM _ =>
+      val concl = Logic.strip_imp_concl prop
+      val (lhs,_) = Logic.dest_equals concl handle TERM _ =>
                       raise SIMPLIFIER("Rewrite rule not a meta-equality",thm)
-      val perm = var_perm(lhs,rhs) andalso not(lhs=rhs orelse is_Var(lhs))
-  in if not perm andalso loops sign prems (lhs,rhs)
+      val econcl = Pattern.eta_contract concl
+      val (elhs,erhs) = Logic.dest_equals econcl
+      val perm = var_perm(elhs,erhs) andalso not(elhs aconv erhs)
+                                     andalso not(is_Var(elhs))
+  in if not perm andalso loops sign prems (elhs,erhs)
      then (prtm "Warning: ignoring looping rewrite rule" sign prop; None)
      else Some{thm=thm,lhs=lhs,perm=perm}
   end;
@@ -1082,7 +1085,7 @@
 fun add_cong(Mss{net,congs,bounds,prems,mk_rews},thm) =
   let val (lhs,_) = Logic.dest_equals(concl_of thm) handle TERM _ =>
                     raise SIMPLIFIER("Congruence not a meta-equality",thm)
-      val lhs = Pattern.eta_contract lhs
+(*      val lhs = Pattern.eta_contract lhs*)
       val (a,_) = dest_Const (head_of lhs) handle TERM _ =>
                   raise SIMPLIFIER("Congruence must start with a constant",thm)
   in Mss{net=net, congs=(a,{lhs=lhs,thm=thm})::congs, bounds=bounds,
@@ -1169,9 +1172,6 @@
      | _ => err()
   end;
 
-(* This code doesn't help at the moment because many bound vars in rewrite
-   rules are eliminated by eta-contraction. Thus the names of bound vars are
-   lost upon rewriting.
 fun ren_inst(insts,prop,pat,obj) =
   let val ren = match_bvs(pat,obj,[])
       fun renAbs(Abs(x,T,b)) =
@@ -1179,18 +1179,17 @@
         | renAbs(f$t) = renAbs(f) $ renAbs(t)
         | renAbs(t) = t
   in subst_vars insts (if null(ren) then prop else renAbs(prop)) end;
-*)
-fun ren_inst(insts,prop,_,_) = subst_vars insts prop;
+
 
 (*Conversion to apply the meta simpset to a term*)
 fun rewritec (prover,signt) (mss as Mss{net,...}) (hypst,t) =
-  let val t = Pattern.eta_contract t;
+  let val etat = Pattern.eta_contract t;
       fun rew {thm as Thm{sign,hyps,maxidx,prop,...}, lhs, perm} =
         let val unit = if Sign.subsig(sign,signt) then ()
                   else (trace_thm"Warning: rewrite rule from different theory"
                           thm;
                         raise Pattern.MATCH)
-            val insts = Pattern.match (#tsig(Sign.rep_sg signt)) (lhs,t)
+            val insts = Pattern.match (#tsig(Sign.rep_sg signt)) (lhs,etat)
             val prop' = ren_inst(insts,prop,lhs,t);
             val hyps' = hyps union hypst;
             val thm' = Thm{sign=signt, hyps=hyps', prop=prop', maxidx=maxidx}
@@ -1209,9 +1208,9 @@
             let val opt = rew rrule handle Pattern.MATCH => None
             in case opt of None => rews rrules | some => some end;
 
-  in case t of
+  in case etat of
        Abs(_,_,body) $ u => Some(hypst,subst_bounds([u], body))
-     | _                 => rews(Net.match_term net t)
+     | _                 => rews(Net.match_term net etat)
   end;
 
 (*Conversion to apply a congruence rule to a term*)