src/Pure/pattern.ML
changeset 15570 8d8c70b41bab
parent 15531 08c8dad8e399
child 15574 b1d1b5bfc464
equal deleted inserted replaced
15569:1b3115d1a8df 15570:8d8c70b41bab
    (*Diagnostics switch for unification failures; off by default.
      typ_clash consults !trace_unify_fail before rendering the clashing
      types.*)
    val trace_unify_fail : bool ref = ref false;
    47 
    47 
    (*Render term t for diagnostics.
      The binders (name/type pairs) are substituted into t as Free terms
      via subst_bounds. The result is normalised under env with
      Envir.norm_term and printed with Sign.string_of_term in sg.*)
    fun string_of_term sg env binders t =
      let
        val t_closed = subst_bounds (map Free binders, t);
        val t_norm = Envir.norm_term env t_closed;
      in Sign.string_of_term sg t_norm end;
    50 
    50 
    (*Name component of the i-th entry (0-based) of binders.
      List.nth raises Subscript if i is out of range.*)
    fun bname binders i =
      let val (name, _) = List.nth (binders, i) in name end;
    (*Names of the binders at the indices in is, joined with single spaces.*)
    fun bnames binders is =
      let val names = map (fn i => bname binders i) is
      in space_implode " " names end;
    53 
    53 
    54 fun typ_clash sg (tye,T,U) =
    54 fun typ_clash sg (tye,T,U) =
    55   if !trace_unify_fail
    55   if !trace_unify_fail
    56   then let val t = Sign.string_of_typ sg (Envir.norm_type tye T)
    56   then let val t = Sign.string_of_typ sg (Envir.norm_type tye T)
   111     in mpb 0 end;
   111     in mpb 0 end;
   112 
   112 
   113 fun idx [] j     = raise Unif
   113 fun idx [] j     = raise Unif
   114   | idx(i::is) j = if i=j then length is else idx is j;
   114   | idx(i::is) j = if i=j then length is else idx is j;
   115 
   115 
   116 fun at xs i = nth_elem (i,xs);
   116 fun at xs i = List.nth (xs,i);
   117 
   117 
   118 fun mkabs (binders,is,t)  =
   118 fun mkabs (binders,is,t)  =
   119     let fun mk(i::is) = let val (x,T) = nth_elem(i,binders)
   119     let fun mk(i::is) = let val (x,T) = List.nth(binders,i)
   120                         in Abs(x,T,mk is) end
   120                         in Abs(x,T,mk is) end
   121           | mk []     = t
   121           | mk []     = t
   122     in mk is end;
   122     in mk is end;
   123 
   123 
   124 val incr = mapbnd (fn i => i+1);
   124 val incr = mapbnd (fn i => i+1);
   156   in Envir.update((F,mkhnf(binders,is,G,js)),env') end;
   156   in Envir.update((F,mkhnf(binders,is,G,js)),env') end;
   157 
   157 
   158 
   158 
   159 (* mk_proj_list(is) = [ |is| - k | 1 <= k <= |is| and is[k] >= 0 ] *)
   159 (* mk_proj_list(is) = [ |is| - k | 1 <= k <= |is| and is[k] >= 0 ] *)
   160 fun mk_proj_list is =
   160 fun mk_proj_list is =
   161     let fun mk(i::is,j) = if is_some i then j :: mk(is,j-1) else mk(is,j-1)
   161     let fun mk(i::is,j) = if isSome i then j :: mk(is,j-1) else mk(is,j-1)
   162           | mk([],_)    = []
   162           | mk([],_)    = []
   163     in mk(is,length is - 1) end;
   163     in mk(is,length is - 1) end;
   164 
   164 
   165 fun proj(s,env,binders,is) =
   165 fun proj(s,env,binders,is) =
   166     let fun trans d i = if i<d then i else (idx is (i-d))+d;
   166     let fun trans d i = if i<d then i else (idx is (i-d))+d;
   180                          in (list_comb(Bound j,ts'),env') end
   180                          in (list_comb(Bound j,ts'),env') end
   181                  | (Var(F as (a,_),Fty),ts) =>
   181                  | (Var(F as (a,_),Fty),ts) =>
   182                       let val js = ints_of' env ts;
   182                       let val js = ints_of' env ts;
   183                           val js' = map (try (trans d)) js;
   183                           val js' = map (try (trans d)) js;
   184                           val ks = mk_proj_list js';
   184                           val ks = mk_proj_list js';
   185                           val ls = mapfilter I js'
   185                           val ls = List.mapPartial I js'
   186                           val Hty = type_of_G env (Fty,length js,ks)
   186                           val Hty = type_of_G env (Fty,length js,ks)
   187                           val (env',H) = Envir.genvar a (env,Hty)
   187                           val (env',H) = Envir.genvar a (env,Hty)
   188                           val env'' =
   188                           val env'' =
   189                                 Envir.update((F,mkhnf(binders,js,H,ks)),env')
   189                                 Envir.update((F,mkhnf(binders,js,H,ks)),env')
   190                       in (app(H,ls),env'') end
   190                       in (app(H,ls),env'') end
   258       | ((Bound i,_),(Free(f,_),_))    => (clashB binders i f; raise Unif)
   258       | ((Bound i,_),(Free(f,_),_))    => (clashB binders i f; raise Unif)
   259 
   259 
   260 
   260 
   261 and rigidrigid sg (env,binders,(a,Ta),(b,Tb),ss,ts) =
   261 and rigidrigid sg (env,binders,(a,Ta),(b,Tb),ss,ts) =
   262       if a<>b then (clash a b; raise Unif)
   262       if a<>b then (clash a b; raise Unif)
   263       else foldl (unif sg binders) (unify_types sg (Ta,Tb,env), ss~~ts)
   263       else Library.foldl (unif sg binders) (unify_types sg (Ta,Tb,env), ss~~ts)
   264 
   264 
   265 and rigidrigidB sg (env,binders,i,j,ss,ts) =
   265 and rigidrigidB sg (env,binders,i,j,ss,ts) =
   266      if i <> j then (clashBB binders i j; raise Unif)
   266      if i <> j then (clashBB binders i j; raise Unif)
   267      else foldl (unif sg binders) (env ,ss~~ts)
   267      else Library.foldl (unif sg binders) (env ,ss~~ts)
   268 
   268 
   269 and flexrigid sg (params as (env,binders,F,is,t)) =
   269 and flexrigid sg (params as (env,binders,F,is,t)) =
   270       if occurs(F,t,env) then (ocheck_fail sg (F,t,binders,env); raise Unif)
   270       if occurs(F,t,env) then (ocheck_fail sg (F,t,binders,env); raise Unif)
   271       else (let val (u,env') = proj(t,env,binders,is)
   271       else (let val (u,env') = proj(t,env,binders,is)
   272             in Envir.update((F,mkabs(binders,is,u)),env') end
   272             in Envir.update((F,mkabs(binders,is,u)),env') end
   273             handle Unif => (proj_fail sg params; raise Unif));
   273             handle Unif => (proj_fail sg params; raise Unif));
   274 
   274 
   275 fun unify(sg,env,tus) = foldl (unif sg []) (env,tus);
   275 fun unify(sg,env,tus) = Library.foldl (unif sg []) (env,tus);
   276 
   276 
   277 
   277 
   278 (*Eta-contract a term (fully)*)
   278 (*Eta-contract a term (fully)*)
   279 
   279 
   280 fun eta_contract t =
   280 fun eta_contract t =
   405             | _ => cases(binders,env,pat,obj))
   405             | _ => cases(binders,env,pat,obj))
   406 
   406 
   407   and cases(binders,env as (iTs,itms),pat,obj) =
   407   and cases(binders,env as (iTs,itms),pat,obj) =
   408     let val (ph,pargs) = strip_comb pat
   408     let val (ph,pargs) = strip_comb pat
   409         fun rigrig1(iTs,oargs) =
   409         fun rigrig1(iTs,oargs) =
   410               foldl (mtch binders) ((iTs,itms), pargs~~oargs)
   410               Library.foldl (mtch binders) ((iTs,itms), pargs~~oargs)
   411         fun rigrig2((a,Ta),(b,Tb),oargs) =
   411         fun rigrig2((a,Ta),(b,Tb),oargs) =
   412               if a<> b then raise MATCH
   412               if a<> b then raise MATCH
   413               else rigrig1(typ_match tsg (iTs,(Ta,Tb)), oargs)
   413               else rigrig1(typ_match tsg (iTs,(Ta,Tb)), oargs)
   414     in case ph of
   414     in case ph of
   415          Var(ixn,_) =>
   415          Var(ixn,_) =>
   474 fun rewrite_term tsig rules procs tm =
   474 fun rewrite_term tsig rules procs tm =
   475   let
   475   let
   476     val skel0 = Bound 0;
   476     val skel0 = Bound 0;
   477 
   477 
   478     val rhs_names =
   478     val rhs_names =
   479       foldr (fn ((_, rhs), names) => add_term_free_names (rhs, names)) (rules, []);
   479       Library.foldr (fn ((_, rhs), names) => add_term_free_names (rhs, names)) (rules, []);
   480 
   480 
   481     fun variant_absfree (x, T, t) =
   481     fun variant_absfree (x, T, t) =
   482       let
   482       let
   483         val x' = variant (add_term_free_names (t, rhs_names)) x;
   483         val x' = variant (add_term_free_names (t, rhs_names)) x;
   484         val t' = subst_bound (Free (x', T), t);
   484         val t' = subst_bound (Free (x', T), t);
   485       in (fn u => Abs (x, T, abstract_over (Free (x', T), u)), t') end;
   485       in (fn u => Abs (x, T, abstract_over (Free (x', T), u)), t') end;
   486 
   486 
   487     fun match_rew tm (tm1, tm2) =
   487     fun match_rew tm (tm1, tm2) =
   488       let val rtm = if_none (Term.rename_abs tm1 tm tm2) tm2
   488       let val rtm = getOpt (Term.rename_abs tm1 tm tm2, tm2)
   489       in SOME (subst_vars (match tsig (tm1, tm)) rtm, rtm)
   489       in SOME (subst_vars (match tsig (tm1, tm)) rtm, rtm)
   490         handle MATCH => NONE
   490         handle MATCH => NONE
   491       end;
   491       end;
   492 
   492 
   493     fun rew (Abs (_, _, body) $ t) = SOME (subst_bound (t, body), skel0)
   493     fun rew (Abs (_, _, body) $ t) = SOME (subst_bound (t, body), skel0)
   494       | rew tm = (case get_first (match_rew tm) rules of
   494       | rew tm = (case get_first (match_rew tm) rules of
   495           NONE => apsome (rpair skel0) (get_first (fn p => p tm) procs)
   495           NONE => Option.map (rpair skel0) (get_first (fn p => p tm) procs)
   496         | x => x);
   496         | x => x);
   497 
   497 
   498     fun rew1 (Var _) _ = NONE
   498     fun rew1 (Var _) _ = NONE
   499       | rew1 skel tm = (case rew2 skel tm of
   499       | rew1 skel tm = (case rew2 skel tm of
   500           SOME tm1 => (case rew tm1 of
   500           SOME tm1 => (case rew tm1 of
   501               SOME (tm2, skel') => SOME (if_none (rew1 skel' tm2) tm2)
   501               SOME (tm2, skel') => SOME (getOpt (rew1 skel' tm2, tm2))
   502             | NONE => SOME tm1)
   502             | NONE => SOME tm1)
   503         | NONE => (case rew tm of
   503         | NONE => (case rew tm of
   504               SOME (tm1, skel') => SOME (if_none (rew1 skel' tm1) tm1)
   504               SOME (tm1, skel') => SOME (getOpt (rew1 skel' tm1, tm1))
   505             | NONE => NONE))
   505             | NONE => NONE))
   506 
   506 
   507     and rew2 skel (tm1 $ tm2) = (case tm1 of
   507     and rew2 skel (tm1 $ tm2) = (case tm1 of
   508             Abs (_, _, body) =>
   508             Abs (_, _, body) =>
   509               let val tm' = subst_bound (tm2, body)
   509               let val tm' = subst_bound (tm2, body)
   510               in SOME (if_none (rew2 skel0 tm') tm') end
   510               in SOME (getOpt (rew2 skel0 tm', tm')) end
   511           | _ =>
   511           | _ =>
   512             let val (skel1, skel2) = (case skel of
   512             let val (skel1, skel2) = (case skel of
   513                 skel1 $ skel2 => (skel1, skel2)
   513                 skel1 $ skel2 => (skel1, skel2)
   514               | _ => (skel0, skel0))
   514               | _ => (skel0, skel0))
   515             in case rew1 skel1 tm1 of
   515             in case rew1 skel1 tm1 of
   528               SOME tm'' => SOME (abs tm'')
   528               SOME tm'' => SOME (abs tm'')
   529             | NONE => NONE
   529             | NONE => NONE
   530           end
   530           end
   531       | rew2 _ _ = NONE
   531       | rew2 _ _ = NONE
   532 
   532 
   533   in if_none (rew1 skel0 tm) tm end;
   533   in getOpt (rew1 skel0 tm, tm) end;
   534 
   534 
   535 end;
   535 end;
   536 
   536 
   537 val trace_unify_fail = Pattern.trace_unify_fail;
   537 val trace_unify_fail = Pattern.trace_unify_fail;