src/Provers/splitter.ML
author haftmann
Tue Nov 24 17:28:25 2009 +0100 (2009-11-24)
changeset 33955 fff6f11b1f09
parent 33317 b4534348b8fd
child 35625 9c818cab0dd0
permissions -rw-r--r--
curried take/drop
     1 (*  Title:      Provers/splitter.ML
     2     Author:     Tobias Nipkow
     3     Copyright   1995  TU Munich
     4 
     5 Generic case-splitter, suitable for most logics.
     6 Deals with equalities of the form ?P(f args) = ...
     7 where "f args" must be a first-order term without duplicate variables.
     8 *)
     9 
    10 infix 4 addsplits delsplits;
    11 
    12 signature SPLITTER_DATA =
    13 sig
    14   val thy           : theory
    15   val mk_eq         : thm -> thm
    16   val meta_eq_to_iff: thm (* "x == y ==> x = y"                      *)
    17   val iffD          : thm (* "[| P = Q; Q |] ==> P"                  *)
    18   val disjE         : thm (* "[| P | Q; P ==> R; Q ==> R |] ==> R"   *)
    19   val conjE         : thm (* "[| P & Q; [| P; Q |] ==> R |] ==> R"   *)
    20   val exE           : thm (* "[| EX x. P x; !!x. P x ==> Q |] ==> Q" *)
    21   val contrapos     : thm (* "[| ~ Q; P ==> Q |] ==> ~ P"            *)
    22   val contrapos2    : thm (* "[| Q; ~ P ==> ~ Q |] ==> P"            *)
    23   val notnotD       : thm (* "~ ~ P ==> P"                           *)
    24 end
    25 
    26 signature SPLITTER =
    27 sig
    28   (* somewhat more internal functions *)
    29   val cmap_of_split_thms: thm list -> (string * (typ * term * thm * typ * int) list) list
    30   val split_posns: (string * (typ * term * thm * typ * int) list) list ->
    31     theory -> typ list -> term -> (thm * (typ * typ * int list) list * int list * typ * term) list
    32     (* first argument is a "cmap", returns a list of "split packs" *)
    33   (* the "real" interface, providing a number of tactics *)
    34   val split_tac       : thm list -> int -> tactic
    35   val split_inside_tac: thm list -> int -> tactic
    36   val split_asm_tac   : thm list -> int -> tactic
    37   val addsplits       : simpset * thm list -> simpset
    38   val delsplits       : simpset * thm list -> simpset
    39   val split_add: attribute
    40   val split_del: attribute
    41   val split_modifiers : Method.modifier parser list
    42   val setup: theory -> theory
    43 end;
    44 
    45 functor Splitter(Data: SPLITTER_DATA): SPLITTER =
    46 struct
    47 
    48 val Const (const_not, _) $ _ =
    49   ObjectLogic.drop_judgment Data.thy
    50     (#1 (Logic.dest_implies (Thm.prop_of Data.notnotD)));
    51 
    52 val Const (const_or , _) $ _ $ _ =
    53   ObjectLogic.drop_judgment Data.thy
    54     (#1 (Logic.dest_implies (Thm.prop_of Data.disjE)));
    55 
    56 val const_Trueprop = ObjectLogic.judgment_name Data.thy;
    57 
    58 
    59 fun split_format_err () = error "Wrong format for split rule";
    60 
    61 fun split_thm_info thm = case concl_of (Data.mk_eq thm) of
    62      Const("==", _) $ (Var _ $ t) $ c => (case strip_comb t of
    63        (Const p, _) => (p, case c of (Const (s, _) $ _) => s = const_not | _ => false)
    64      | _ => split_format_err ())
    65    | _ => split_format_err ();
    66 
    67 fun cmap_of_split_thms thms =
    68 let
    69   val splits = map Data.mk_eq thms
    70   fun add_thm thm cmap =
    71     (case concl_of thm of _ $ (t as _ $ lhs) $ _ =>
    72        (case strip_comb lhs of (Const(a,aT),args) =>
    73           let val info = (aT,lhs,thm,fastype_of t,length args)
    74           in case AList.lookup (op =) cmap a of
    75                SOME infos => AList.update (op =) (a, info::infos) cmap
    76              | NONE => (a,[info])::cmap
    77           end
    78         | _ => split_format_err())
    79      | _ => split_format_err())
    80 in
    81   fold add_thm splits []
    82 end;
    83 
    84 (* ------------------------------------------------------------------------- *)
    85 (* mk_case_split_tac                                                         *)
    86 (* ------------------------------------------------------------------------- *)
    87 
    88 fun mk_case_split_tac order =
    89 let
    90 
    91 (************************************************************
    92    Create lift-theorem "trlift" :
    93 
    94    [| !!x. Q x == R x; P(%x. R x) == C |] ==> P (%x. Q x) == C
    95 
    96 *************************************************************)
    97 
    98 val meta_iffD = Data.meta_eq_to_iff RS Data.iffD;  (* (P == Q) ==> Q ==> P *)
    99 
   100 val lift = Goal.prove_global Pure.thy ["P", "Q", "R"]
   101   [Syntax.read_prop_global Pure.thy "!!x :: 'b. Q(x) == R(x) :: 'c"]
   102   (Syntax.read_prop_global Pure.thy "P(%x. Q(x)) == P(%x. R(x))")
   103   (fn {prems, ...} => rewrite_goals_tac prems THEN rtac reflexive_thm 1)
   104 
   105 val trlift = lift RS transitive_thm;
   106 val _ $ (P $ _) $ _ = concl_of trlift;
   107 
   108 
   109 (************************************************************************
   110    Set up term for instantiation of P in the lift-theorem
   111 
   112    Ts    : types of parameters (i.e. variables bound by meta-quantifiers)
   113    t     : lefthand side of meta-equality in subgoal
   114            the lift theorem is applied to (see select)
   115    pos   : "path" leading to abstraction, coded as a list
   116    T     : type of body of P(...)
   117    maxi  : maximum index of Vars
   118 *************************************************************************)
   119 
   120 fun mk_cntxt Ts t pos T maxi =
   121   let fun var (t,i) = Var(("X",i),type_of1(Ts,t));
   122       fun down [] t i = Bound 0
   123         | down (p::ps) t i =
   124             let val (h,ts) = strip_comb t
   125                 val v1 = ListPair.map var (take p ts, i upto (i+p-1))
   126                 val u::us = drop p ts
   127                 val v2 = ListPair.map var (us, (i+p) upto (i+length(ts)-2))
   128       in list_comb(h,v1@[down ps u (i+length ts)]@v2) end;
   129   in Abs("", T, down (rev pos) t maxi) end;
   130 
   131 
   132 (************************************************************************
   133    Set up term for instantiation of P in the split-theorem
   134    P(...) == rhs
   135 
   136    t     : lefthand side of meta-equality in subgoal
   137            the split theorem is applied to (see select)
   138    T     : type of body of P(...)
   139    tt    : the term  Const(key,..) $ ...
   140 *************************************************************************)
   141 
   142 fun mk_cntxt_splitthm t tt T =
   143   let fun repl lev t =
   144     if Pattern.aeconv(incr_boundvars lev tt, t) then Bound lev
   145     else case t of
   146         (Abs (v, T2, t)) => Abs (v, T2, repl (lev+1) t)
   147       | (Bound i) => Bound (if i>=lev then i+1 else i)
   148       | (t1 $ t2) => (repl lev t1) $ (repl lev t2)
   149       | t => t
   150   in Abs("", T, repl 0 t) end;
   151 
   152 
   153 (* add all loose bound variables in t to list is *)
   154 fun add_lbnos t is = add_loose_bnos (t, 0, is);
   155 
   156 (* check if the innermost abstraction that needs to be removed
   157    has a body of type T; otherwise the expansion thm will fail later on
   158 *)
   159 fun type_test (T, lbnos, apsns) =
   160   let val (_, U: typ, _) = List.nth (apsns, foldl1 Int.min lbnos)
   161   in T = U end;
   162 
   163 (*************************************************************************
   164    Create a "split_pack".
   165 
   166    thm   : the relevant split-theorem, i.e. P(...) == rhs , where P(...)
   167            is of the form
   168            P( Const(key,...) $ t_1 $ ... $ t_n )      (e.g. key = "if")
   169    T     : type of P(...)
   170    T'    : type of term to be scanned
   171    n     : number of arguments expected by Const(key,...)
   172    ts    : list of arguments actually found
   173    apsns : list of tuples of the form (T,U,pos), one tuple for each
   174            abstraction that is encountered on the way to the position where
   175            Const(key, ...) $ ...  occurs, where
   176            T   : type of the variable bound by the abstraction
   177            U   : type of the abstraction's body
   178            pos : "path" leading to the body of the abstraction
   179    pos   : "path" leading to the position where Const(key, ...) $ ...  occurs.
   180    TB    : type of  Const(key,...) $ t_1 $ ... $ t_n
   181    t     : the term Const(key,...) $ t_1 $ ... $ t_n
   182 
   183    A split pack is a tuple of the form
   184    (thm, apsns, pos, TB, tt)
   185    Note : apsns is reversed, so that the outermost quantifier's position
   186           comes first ! If the terms in ts don't contain variables bound
   187           by other than meta-quantifiers, apsns is empty, because no further
   188           lifting is required before applying the split-theorem.
   189 ******************************************************************************)
   190 
   191 fun mk_split_pack (thm, T: typ, T', n, ts, apsns, pos, TB, t) =
   192   if n > length ts then []
   193   else let val lev = length apsns
   194            val lbnos = fold add_lbnos (take n ts) []
   195            val flbnos = filter (fn i => i < lev) lbnos
   196            val tt = incr_boundvars (~lev) t
   197        in if null flbnos then
   198             if T = T' then [(thm,[],pos,TB,tt)] else []
   199           else if type_test(T,flbnos,apsns) then [(thm, rev apsns,pos,TB,tt)]
   200                else []
   201        end;
   202 
   203 
   204 (****************************************************************************
   205    Recursively scans term for occurences of Const(key,...) $ ...
   206    Returns a list of "split-packs" (one for each occurence of Const(key,...) )
   207 
   208    cmap : association list of split-theorems that should be tried.
   209           The elements have the format (key,(thm,T,n)) , where
   210           key : the theorem's key constant ( Const(key,...) $ ... )
   211           thm : the theorem itself
   212           T   : type of P( Const(key,...) $ ... )
   213           n   : number of arguments expected by Const(key,...)
   214    Ts   : types of parameters
   215    t    : the term to be scanned
   216 ******************************************************************************)
   217 
   218 (* Simplified first-order matching;
   219    assumes that all Vars in the pattern are distinct;
   220    see Pure/pattern.ML for the full version;
   221 *)
   222 local
   223   exception MATCH
   224 in
   225   fun typ_match sg (tyenv, TU) = (Sign.typ_match sg TU tyenv)
   226                             handle Type.TYPE_MATCH => raise MATCH
   227 
   228   fun fomatch sg args =
   229     let
   230       fun mtch tyinsts = fn
   231           (Ts, Var(_,T), t) =>
   232             typ_match sg (tyinsts, (T, fastype_of1(Ts,t)))
   233         | (_, Free (a,T), Free (b,U)) =>
   234             if a=b then typ_match sg (tyinsts,(T,U)) else raise MATCH
   235         | (_, Const (a,T), Const (b,U)) =>
   236             if a=b then typ_match sg (tyinsts,(T,U)) else raise MATCH
   237         | (_, Bound i, Bound j) =>
   238             if i=j then tyinsts else raise MATCH
   239         | (Ts, Abs(_,T,t), Abs(_,U,u)) =>
   240             mtch (typ_match sg (tyinsts,(T,U))) (U::Ts,t,u)
   241         | (Ts, f$t, g$u) =>
   242             mtch (mtch tyinsts (Ts,f,g)) (Ts, t, u)
   243         | _ => raise MATCH
   244     in (mtch Vartab.empty args; true) handle MATCH => false end;
   245 end;
   246 
   247 fun split_posns (cmap : (string * (typ * term * thm * typ * int) list) list) sg Ts t =
   248   let
   249     val T' = fastype_of1 (Ts, t);
   250     fun posns Ts pos apsns (Abs (_, T, t)) =
   251           let val U = fastype_of1 (T::Ts,t)
   252           in posns (T::Ts) (0::pos) ((T, U, pos)::apsns) t end
   253       | posns Ts pos apsns t =
   254           let
   255             val (h, ts) = strip_comb t
   256             fun iter t (i, a) = (i+1, (posns Ts (i::pos) apsns t) @ a);
   257             val a =
   258               case h of
   259                 Const(c, cT) =>
   260                   let fun find [] = []
   261                         | find ((gcT, pat, thm, T, n)::tups) =
   262                             let val t2 = list_comb (h, take n ts)
   263                             in if Sign.typ_instance sg (cT, gcT)
   264                                   andalso fomatch sg (Ts,pat,t2)
   265                                then mk_split_pack(thm,T,T',n,ts,apsns,pos,type_of1(Ts,t2),t2)
   266                                else find tups
   267                             end
   268                   in find (these (AList.lookup (op =) cmap c)) end
   269               | _ => []
   270           in snd (fold iter ts (0, a)) end
   271   in posns Ts [] [] t end;
   272 
   273 fun nth_subgoal i thm = List.nth (prems_of thm, i-1);
   274 
   275 fun shorter ((_,ps,pos,_,_), (_,qs,qos,_,_)) =
   276   prod_ord (int_ord o pairself length) (order o pairself length)
   277     ((ps, pos), (qs, qos));
   278 
   279 
   280 (************************************************************
   281    call split_posns with appropriate parameters
   282 *************************************************************)
   283 
   284 fun select cmap state i =
   285   let val sg = Thm.theory_of_thm state
   286       val goali = nth_subgoal i state
   287       val Ts = rev(map #2 (Logic.strip_params goali))
   288       val _ $ t $ _ = Logic.strip_assums_concl goali;
   289   in (Ts, t, sort shorter (split_posns cmap sg Ts t)) end;
   290 
   291 fun exported_split_posns cmap sg Ts t =
   292   sort shorter (split_posns cmap sg Ts t);
   293 
   294 (*************************************************************
   295    instantiate lift theorem
   296 
   297    if t is of the form
   298    ... ( Const(...,...) $ Abs( .... ) ) ...
   299    then
   300    P = %a.  ... ( Const(...,...) $ a ) ...
   301    where a has type T --> U
   302 
   303    Ts      : types of parameters
   304    t       : lefthand side of meta-equality in subgoal
   305              the split theorem is applied to (see cmap)
   306    T,U,pos : see mk_split_pack
   307    state   : current proof state
   308    lift    : the lift theorem
   309    i       : no. of subgoal
   310 **************************************************************)
   311 
   312 fun inst_lift Ts t (T, U, pos) state i =
   313   let
   314     val cert = cterm_of (Thm.theory_of_thm state);
   315     val cntxt = mk_cntxt Ts t pos (T --> U) (Thm.maxidx_of trlift);
   316   in cterm_instantiate [(cert P, cert cntxt)] trlift
   317   end;
   318 
   319 
   320 (*************************************************************
   321    instantiate split theorem
   322 
   323    Ts    : types of parameters
   324    t     : lefthand side of meta-equality in subgoal
   325            the split theorem is applied to (see cmap)
   326    tt    : the term  Const(key,..) $ ...
   327    thm   : the split theorem
   328    TB    : type of body of P(...)
   329    state : current proof state
   330    i     : number of subgoal
   331 **************************************************************)
   332 
   333 fun inst_split Ts t tt thm TB state i =
   334   let
   335     val thm' = Thm.lift_rule (Thm.cprem_of state i) thm;
   336     val (P, _) = strip_comb (fst (Logic.dest_equals
   337       (Logic.strip_assums_concl (Thm.prop_of thm'))));
   338     val cert = cterm_of (Thm.theory_of_thm state);
   339     val cntxt = mk_cntxt_splitthm t tt TB;
   340     val abss = fold (fn T => fn t => Abs ("", T, t));
   341   in cterm_instantiate [(cert P, cert (abss Ts cntxt))] thm'
   342   end;
   343 
   344 
   345 (*****************************************************************************
   346    The split-tactic
   347 
   348    splits : list of split-theorems to be tried
   349    i      : number of subgoal the tactic should be applied to
   350 *****************************************************************************)
   351 
   352 fun split_tac [] i = no_tac
   353   | split_tac splits i =
   354   let val cmap = cmap_of_split_thms splits
   355       fun lift_tac Ts t p st = rtac (inst_lift Ts t p st i) i st
   356       fun lift_split_tac state =
   357             let val (Ts, t, splits) = select cmap state i
   358             in case splits of
   359                  [] => no_tac state
   360                | (thm, apsns, pos, TB, tt)::_ =>
   361                    (case apsns of
   362                       [] => compose_tac (false, inst_split Ts t tt thm TB state i, 0) i state
   363                     | p::_ => EVERY [lift_tac Ts t p,
   364                                      rtac reflexive_thm (i+1),
   365                                      lift_split_tac] state)
   366             end
   367   in COND (has_fewer_prems i) no_tac
   368           (rtac meta_iffD i THEN lift_split_tac)
   369   end;
   370 
   371 in (split_tac, exported_split_posns) end;  (* mk_case_split_tac *)
   372 
   373 
   374 val (split_tac, split_posns) = mk_case_split_tac int_ord;
   375 
   376 val (split_inside_tac, _) = mk_case_split_tac (rev_order o int_ord);
   377 
   378 
   379 (*****************************************************************************
   380    The split-tactic for premises
   381 
   382    splits : list of split-theorems to be tried
   383 ****************************************************************************)
   384 fun split_asm_tac [] = K no_tac
   385   | split_asm_tac splits =
   386 
   387   let val cname_list = map (fst o fst o split_thm_info) splits;
   388       fun tac (t,i) =
   389           let val n = find_index (exists_Const (member (op =) cname_list o #1))
   390                                  (Logic.strip_assums_hyp t);
   391               fun first_prem_is_disj (Const ("==>", _) $ (Const (c, _)
   392                     $ (Const (s, _) $ _ $ _ )) $ _ ) = c = const_Trueprop andalso s = const_or
   393               |   first_prem_is_disj (Const("all",_)$Abs(_,_,t)) =
   394                                         first_prem_is_disj t
   395               |   first_prem_is_disj _ = false;
   396       (* does not work properly if the split variable is bound by a quantifier *)
   397               fun flat_prems_tac i = SUBGOAL (fn (t,i) =>
   398                            (if first_prem_is_disj t
   399                             then EVERY[etac Data.disjE i,rotate_tac ~1 i,
   400                                        rotate_tac ~1  (i+1),
   401                                        flat_prems_tac (i+1)]
   402                             else all_tac)
   403                            THEN REPEAT (eresolve_tac [Data.conjE,Data.exE] i)
   404                            THEN REPEAT (dresolve_tac [Data.notnotD]   i)) i;
   405           in if n<0 then  no_tac  else (DETERM (EVERY'
   406                 [rotate_tac n, etac Data.contrapos2,
   407                  split_tac splits,
   408                  rotate_tac ~1, etac Data.contrapos, rotate_tac ~1,
   409                  flat_prems_tac] i))
   410           end;
   411   in SUBGOAL tac
   412   end;
   413 
   414 fun gen_split_tac [] = K no_tac
   415   | gen_split_tac (split::splits) =
   416       let val (_,asm) = split_thm_info split
   417       in (if asm then split_asm_tac else split_tac) [split] ORELSE'
   418          gen_split_tac splits
   419       end;
   420 
   421 
   422 (** declare split rules **)
   423 
   424 (* addsplits / delsplits *)
   425 
   426 fun string_of_typ (Type (s, Ts)) =
   427       (if null Ts then "" else enclose "(" ")" (commas (map string_of_typ Ts))) ^ s
   428   | string_of_typ _ = "_";
   429 
   430 fun split_name (name, T) asm = "split " ^
   431   (if asm then "asm " else "") ^ name ^ " :: " ^ string_of_typ T;
   432 
   433 fun ss addsplits splits =
   434   let
   435     fun addsplit split ss =
   436       let
   437         val (name, asm) = split_thm_info split
   438         val tac = (if asm then split_asm_tac else split_tac) [split]
   439       in Simplifier.addloop (ss, (split_name name asm, tac)) end
   440   in fold addsplit splits ss end;
   441 
   442 fun ss delsplits splits =
   443   let
   444     fun delsplit split ss =
   445       let val (name, asm) = split_thm_info split
   446       in Simplifier.delloop (ss, split_name name asm) end
   447   in fold delsplit splits ss end;
   448 
   449 
   450 (* attributes *)
   451 
   452 val splitN = "split";
   453 
   454 val split_add = Simplifier.attrib (op addsplits);
   455 val split_del = Simplifier.attrib (op delsplits);
   456 
   457 
   458 (* methods *)
   459 
   460 val split_modifiers =
   461  [Args.$$$ splitN -- Args.colon >> K ((I, split_add): Method.modifier),
   462   Args.$$$ splitN -- Args.add -- Args.colon >> K (I, split_add),
   463   Args.$$$ splitN -- Args.del -- Args.colon >> K (I, split_del)];
   464 
   465 
   466 (* theory setup *)
   467 
   468 val setup =
   469   Attrib.setup @{binding split}
   470     (Attrib.add_del split_add split_del) "declare case split rule" #>
   471   Method.setup @{binding split}
   472     (Attrib.thms >> (fn ths => K (SIMPLE_METHOD' (CHANGED_PROP o gen_split_tac ths))))
   473     "apply case split rule";
   474 
   475 end;