src/Provers/splitter.ML
author wenzelm
Mon Aug 28 20:29:56 2000 +0200 (2000-08-28)
changeset 9703 bf65780eed02
parent 9267 dbf30a2d1b56
child 9807 64b7f756c8f0
permissions -rw-r--r--
added 'split' method;
     1 (*  Title:      Provers/splitter
     2     ID:         $Id$
     3     Author:     Tobias Nipkow
     4     Copyright   1995  TU Munich
     5 
     6 Generic case-splitter, suitable for most logics.
     7 *)
     8 
     9 infix 4 addsplits; infix 4 delsplits;  (* enables  ss addsplits thms  /  ss delsplits thms *)
    10 
    11 signature SPLITTER_DATA =   (* object-logic rules from which SplitterFun builds the splitter *)
    12 sig
    13   structure Simplifier: SIMPLIFIER
    14   val mk_eq         : thm -> thm  (* must yield a thm whose conclusion is a "==" (see split_thm_info) *)
    15   val meta_eq_to_iff: thm (* "x == y ==> x = y"                    *)
    16   val iffD          : thm (* "[| P = Q; Q |] ==> P"                *)
    17   val disjE         : thm (* "[| P | Q; P ==> R; Q ==> R |] ==> R" *)  (* exact shape required: the disjunction constant is read off it *)
    18   val conjE         : thm (* "[| P & Q; [| P; Q |] ==> R |] ==> R" *)
    19   val exE           : thm (* "[| EX x. P x; !!x. P x ==> Q |] ==> Q" *)
    20   val contrapos     : thm (* "[| ~ Q; P ==> Q |] ==> ~ P"          *)
    21   val contrapos2    : thm (* "[| Q; ~ P ==> ~ Q |] ==> P"          *)
    22   val notnotD       : thm (* "~ ~ P ==> P"                         *)  (* exact shape required: the negation constant is read off it *)
    23 end
    24 
    25 signature SPLITTER =   (* operations exported by SplitterFun *)
    26 sig
    27   type simpset
    28   val split_tac       : thm list -> int -> tactic  (* split in the conclusion of subgoal i, outermost occurrence first *)
    29   val split_inside_tac: thm list -> int -> tactic  (* as split_tac, but innermost occurrence first *)
    30   val split_asm_tac   : thm list -> int -> tactic  (* split the first premise of subgoal i mentioning a split constant *)
    31   val addsplits       : simpset * thm list -> simpset  (* install split rules as simplifier loop tactics *)
    32   val delsplits       : simpset * thm list -> simpset  (* remove such loop tactics again *)
    33   val Addsplits       : thm list -> unit  (* addsplits on the current global simpset *)
    34   val Delsplits       : thm list -> unit  (* delsplits on the current global simpset *)
    35   val split_add_global: theory attribute  (* attribute forms of addsplits / delsplits *)
    36   val split_del_global: theory attribute
    37   val split_add_local: Proof.context attribute
    38   val split_del_local: Proof.context attribute
    39   val split_modifiers : (Args.T list -> (Method.modifier * Args.T list)) list  (* "split:", "split add:", "split del:" *)
    40   val setup: (theory -> theory) list  (* declares the "split" attribute and method *)
    41 end;
    42 
    43 functor SplitterFun(Data: SPLITTER_DATA): SPLITTER =  (* builds split tactics, simpset operations and Isar setup from Data *)
    44 struct 
    45 
    46 structure Simplifier = Data.Simplifier;
    47 type simpset = Simplifier.simpset;
    48 
    49 val Const ("==>", _) $ (Const ("Trueprop", _) $
    50          (Const (const_not, _) $ _    )) $ _ = #prop (rep_thm(Data.notnotD));  (* const_not: negation constant, read off notnotD *)
    51 
    52 val Const ("==>", _) $ (Const ("Trueprop", _) $
    53          (Const (const_or , _) $ _ $ _)) $ _ = #prop (rep_thm(Data.disjE));  (* const_or: disjunction constant, read off disjE *)
    54 
    55 fun split_format_err() = error("Wrong format for split rule");
    56 
    57 fun split_thm_info thm = case concl_of (Data.mk_eq thm) of  (* -> (split constant name, rhs is a negation, i.e. a rule for premises) *)
    58      Const("==", _)$(Var _$t)$c =>
    59         (case strip_comb t of
    60            (Const(a,_),_) => (a,case c of (Const(s,_)$_)=>s=const_not|_=> false)
    61          | _              => split_format_err())
    62    | _ => split_format_err();
    63 
    64 fun mk_case_split_tac order =  (* order: compares position-path lengths, used by shorter *)
    65 let
    66 
    67 
    68 (************************************************************
    69    Create lift-theorem "trlift" :
    70 
    71    [| !!x. Q x == R x; P(%x. R x) == C |] ==> P (%x. Q x) == C
    72 
    73 *************************************************************)
    74 
    75 val meta_iffD = Data.meta_eq_to_iff RS Data.iffD;  (* roughly [| P == Q; Q |] ==> P *)
    76 val lift =
    77   let val ct = read_cterm (#sign(rep_thm Data.iffD))
    78            ("[| !!x::'b::logic. Q(x) == R(x) |] ==> \
    79             \P(%x. Q(x)) == P(%x. R(x))::'a::logic",propT)
    80   in prove_goalw_cterm [] ct
    81      (fn [prem] => [rewtac prem, rtac reflexive_thm 1])
    82   end;
    83 
    84 val trlift = lift RS transitive_thm;
    85 val _ $ (P $ _) $ _ = concl_of trlift;  (* P: the Var that inst_lift instantiates *)
    86 
    87 
    88 (************************************************************************ 
    89    Set up term for instantiation of P in the lift-theorem
    90    
    91    Ts    : types of parameters (i.e. variables bound by meta-quantifiers)
    92    t     : lefthand side of meta-equality in subgoal
    93            the lift theorem is applied to (see select)
    94    pos   : "path" leading to abstraction, coded as a list
    95    T     : type of body of P(...)
    96    maxi  : maximum index of Vars; the fresh Vars "X" are indexed from here
    97 *************************************************************************)
    98 
    99 fun mk_cntxt Ts t pos T maxi =
   100   let fun var (t,i) = Var(("X",i),type_of1(Ts,t));
   101       fun down [] t i = Bound 0  (* end of the path: the hole *)
   102         | down (p::ps) t i =
   103             let val (h,ts) = strip_comb t
   104                 val v1 = ListPair.map var (take(p,ts), i upto (i+p-1))  (* arguments off the path become fresh Vars *)
   105                 val u::us = drop(p,ts)
   106                 val v2 = ListPair.map var (us, (i+p) upto (i+length(ts)-2))
   107       in list_comb(h,v1@[down ps u (i+length ts)]@v2) end;
   108   in Abs("", T, down (rev pos) t maxi) end;
   109 
   110 
   111 (************************************************************************ 
   112    Set up term for instantiation of P in the split-theorem
   113    P(...) == rhs ; every occurrence of tt in t becomes the bound variable
   114 
   115    t     : lefthand side of meta-equality in subgoal
   116            the split theorem is applied to (see select)
   117    T     : type of body of P(...)
   118    tt    : the term  Const(key,..) $ ...
   119 *************************************************************************)
   120 
   121 fun mk_cntxt_splitthm t tt T =
   122   let fun repl lev t =
   123     if incr_boundvars lev tt aconv t then Bound lev
   124     else case t of
   125         (Abs (v, T2, t)) => Abs (v, T2, repl (lev+1) t)
   126       | (Bound i) => Bound (if i>=lev then i+1 else i)  (* make room for the new outer Abs *)
   127       | (t1 $ t2) => (repl lev t1) $ (repl lev t2)
   128       | t => t
   129   in Abs("", T, repl 0 t) end;
   130 
   131 
   132 (* add all loose bound variables in t to list is *)
   133 fun add_lbnos(is,t) = add_loose_bnos(t,0,is);
   134 
   135 (* check if the innermost abstraction that needs to be removed
   136    has a body of type T; otherwise the expansion thm will fail later on
   137 *)
   138 fun type_test(T,lbnos,apsns) =
   139   let val (_,U,_) = nth_elem(foldl Int.min (hd lbnos, tl lbnos), apsns)
   140   in T=U end;
   141 
   142 (*************************************************************************
   143    Create a "split_pack".
   144 
   145    thm   : the relevant split-theorem, i.e. P(...) == rhs , where P(...)
   146            is of the form
   147            P( Const(key,...) $ t_1 $ ... $ t_n )      (e.g. key = "if")
   148    T     : type of P(...)
   149    T'    : type of term to be scanned
   150    n     : number of arguments expected by Const(key,...)
   151    ts    : list of arguments actually found
   152    apsns : list of tuples of the form (T,U,pos), one tuple for each
   153            abstraction that is encountered on the way to the position where 
   154            Const(key, ...) $ ...  occurs, where
   155            T   : type of the variable bound by the abstraction
   156            U   : type of the abstraction's body
   157            pos : "path" leading to the body of the abstraction
   158    pos   : "path" leading to the position where Const(key, ...) $ ...  occurs.
   159    TB    : type of  Const(key,...) $ t_1 $ ... $ t_n
   160    t     : the term Const(key,...) $ t_1 $ ... $ t_n
   161    Result: [] if the rule is not applicable here, else a one-element list.
   162    A split pack is a tuple of the form
   163    (thm, apsns, pos, TB, tt)
   164    Note : apsns is reversed, so that the outermost quantifier's position
   165           comes first ! If the terms in ts don't contain variables bound
   166           by other than meta-quantifiers, apsns is empty, because no further
   167           lifting is required before applying the split-theorem.
   168 ******************************************************************************) 
   169 
   170 fun mk_split_pack(thm, T, T', n, ts, apsns, pos, TB, t) =
   171   if n > length ts then []
   172   else let val lev = length apsns
   173            val lbnos = foldl add_lbnos ([],take(n,ts))
   174            val flbnos = filter (fn i => i < lev) lbnos  (* loose bounds bound by the abstractions in apsns *)
   175            val tt = incr_boundvars (~lev) t
   176        in if null flbnos then
   177             if T = T' then [(thm,[],pos,TB,tt)] else []
   178           else if type_test(T,flbnos,apsns) then [(thm, rev apsns,pos,TB,tt)]
   179                else []
   180        end;
   181 
   182 
   183 (****************************************************************************
   184    Recursively scans term for occurrences of Const(key,...) $ ...
   185    Returns a list of "split-packs" (at most one for each occurrence of Const(key,...) )
   186    sg   : signature, used to check the constant's type against gcT
   187    cmap : association list of split-theorems that should be tried.
   188           The elements have the format (key,[(gcT,thm,T,n),...]) , where
   189           key : the theorem's key constant ( Const(key,gcT) $ ... )
   190           thm : the theorem itself
   191           T   : type of P( Const(key,...) $ ... )
   192           n   : number of arguments expected by Const(key,...)
   193    Ts   : types of parameters
   194    t    : the term to be scanned
   195 ******************************************************************************)
   196 
   197 fun split_posns cmap sg Ts t =
   198   let
   199     val T' = fastype_of1 (Ts, t);
   200     fun posns Ts pos apsns (Abs (_, T, t)) =
   201           let val U = fastype_of1 (T::Ts,t)
   202           in posns (T::Ts) (0::pos) ((T, U, pos)::apsns) t end
   203       | posns Ts pos apsns t =
   204           let
   205             val (h, ts) = strip_comb t
   206             fun iter((i, a), t) = (i+1, (posns Ts (i::pos) apsns t) @ a);  (* i: argument index, extends the path *)
   207             val a = case h of
   208               Const(c, cT) =>
   209                 let fun find [] = []
   210                       | find ((gcT, thm, T, n)::tups) =
   211                           if Type.typ_instance(Sign.tsig_of sg, cT, gcT)
   212                           then
   213                             let val t2 = list_comb (h, take (n, ts))
   214                             in mk_split_pack(thm,T,T',n,ts,apsns,pos,type_of1(Ts,t2),t2)
   215                             end
   216                           else find tups
   217                 in find (assocs cmap c) end  (* the first rule whose key type matches decides *)
   218             | _ => []
   219           in snd(foldl iter ((0, a), ts)) end
   220   in posns Ts [] [] t end;
   221 
   222 
   223 fun nth_subgoal i thm = nth_elem(i-1,prems_of thm);  (* i counts from 1 *)
   224 
   225 fun shorter((_,ps,pos,_,_),(_,qs,qos,_,_)) =  (* fewer lifting steps (apsns) first, then order on path length *)
   226   prod_ord (int_ord o pairself length) (order o pairself length)
   227     ((ps, pos), (qs, qos));
   228 
   229 
   230 
   231 (************************************************************
   232    call split_posns on subgoal i; returns (Ts, t, split packs sorted by shorter)
   233 *************************************************************)
   234 
   235 fun select cmap state i =
   236   let val sg = #sign(rep_thm state)
   237       val goali = nth_subgoal i state
   238       val Ts = rev(map #2 (Logic.strip_params goali))
   239       val _ $ t $ _ = Logic.strip_assums_concl goali;  (* t: lhs of the meta-equality in the conclusion *)
   240   in (Ts,t, sort shorter (split_posns cmap sg Ts t)) end;
   241 
   242 
   243 (*************************************************************
   244    instantiate lift theorem
   245 
   246    if t is of the form
   247    ... ( Const(...,...) $ Abs( .... ) ) ...
   248    then
   249    P = %a.  ... ( Const(...,...) $ a ) ...
   250    where a has type T --> U
   251 
   252    Ts      : types of parameters
   253    t       : lefthand side of meta-equality in subgoal
   254              the lift theorem is applied to (see select)
   255    T,U,pos : see mk_split_pack
   256    state   : current proof state
   257    trlift  : the lift theorem (taken from the enclosing scope, not a parameter)
   258    i       : no. of subgoal (not used in the body)
   259 **************************************************************)
   260 
   261 fun inst_lift Ts t (T, U, pos) state i =
   262   let
   263     val cert = cterm_of (sign_of_thm state);
   264     val cntxt = mk_cntxt Ts t pos (T --> U) (#maxidx(rep_thm trlift));    
   265   in cterm_instantiate [(cert P, cert cntxt)] trlift
   266   end;
   267 
   268 
   269 (*************************************************************
   270    instantiate split theorem
   271 
   272    Ts    : types of parameters
   273    t     : lefthand side of meta-equality in subgoal
   274            the split theorem is applied to (see select)
   275    tt    : the term  Const(key,..) $ ...
   276    thm   : the split theorem
   277    TB    : type of body of P(...)
   278    state : current proof state
   279    i     : number of subgoal
   280 **************************************************************)
   281 
   282 fun inst_split Ts t tt thm TB state i =
   283   let 
   284     val thm' = Thm.lift_rule (state, i) thm;
   285     val (P, _) = strip_comb (fst (Logic.dest_equals
   286       (Logic.strip_assums_concl (#prop (rep_thm thm')))));
   287     val cert = cterm_of (sign_of_thm state);
   288     val cntxt = mk_cntxt_splitthm t tt TB;
   289     val abss = foldl (fn (t, T) => Abs ("", T, t));  (* wraps a term in one Abs per type in Ts *)
   290   in cterm_instantiate [(cert P, cert (abss (cntxt, Ts)))] thm'
   291   end;
   292 
   293 
   294 (*****************************************************************************
   295    The split-tactic
   296    Resolves subgoal i with meta_iffD, lifts with trlift as needed, then applies the split theorem.
   297    splits : list of split-theorems to be tried
   298    i      : number of subgoal the tactic should be applied to
   299 *****************************************************************************)
   300 
   301 fun split_tac [] i = no_tac
   302   | split_tac splits i =
   303   let val splits = map Data.mk_eq splits;
   304       fun add_thm(cmap,thm) =
   305             (case concl_of thm of _$(t as _$lhs)$_ =>
   306                (case strip_comb lhs of (Const(a,aT),args) =>
   307                   let val info = (aT,thm,fastype_of t,length args)
   308                   in case assoc(cmap,a) of
   309                        Some infos => overwrite(cmap,(a,info::infos))
   310                      | None => (a,[info])::cmap
   311                   end
   312                 | _ => split_format_err())
   313              | _ => split_format_err())
   314       val cmap = foldl add_thm ([],splits);  (* key constant -> all rules for that key *)
   315       fun lift_tac Ts t p st = rtac (inst_lift Ts t p st i) i st
   316       fun lift_split_tac state =
   317             let val (Ts, t, splits) = select cmap state i
   318             in case splits of
   319                  [] => no_tac state
   320                | (thm, apsns, pos, TB, tt)::_ =>
   321                    (case apsns of
   322                       [] => compose_tac (false, inst_split Ts t tt thm TB state i, 0) i state
   323                     | p::_ => EVERY [lift_tac Ts t p,
   324                                      rtac reflexive_thm (i+1),
   325                                      lift_split_tac] state)  (* lift one level, then retry *)
   326             end
   327   in COND (has_fewer_prems i) no_tac (* no subgoal i: fail *)
   328           (rtac meta_iffD i THEN lift_split_tac)
   329   end;
   330 
   331 in split_tac end;
   332 
   333 
   334 val split_tac        = mk_case_split_tac              int_ord;  (* shortest path, i.e. outermost occurrence, first *)
   335 
   336 val split_inside_tac = mk_case_split_tac (rev_order o int_ord);  (* longest path, i.e. innermost occurrence, first *)
   337 
   338 
   339 (*****************************************************************************
   340    The split-tactic for premises
   341    (premise -> conclusion via contrapos2, split_tac, back via contrapos, then flatten)
   342    splits : list of split-theorems to be tried
   343 ****************************************************************************)
   344 fun split_asm_tac []     = K no_tac
   345   | split_asm_tac splits = 
   346 
   347   let val cname_list = map (fst o split_thm_info) splits;
   348       fun is_case (a,_) = a mem cname_list;
   349       fun tac (t,i) = 
   350 	  let val n = find_index (exists_Const is_case) (* n: first premise mentioning a split constant, ~1 if none *)
   351 				 (Logic.strip_assums_hyp t);
   352 	      fun first_prem_is_disj (Const ("==>", _) $ (Const ("Trueprop", _)
   353 				 $ (Const (s, _) $ _ $ _ )) $ _ ) = (s=const_or)
   354 	      |   first_prem_is_disj (Const("all",_)$Abs(_,_,t)) = 
   355 					first_prem_is_disj t
   356 	      |   first_prem_is_disj _ = false;
   357       (* does not work properly if the split variable is bound by a quantifier *)
   358 	      fun flat_prems_tac i = SUBGOAL (fn (t,i) => 
   359 			   (if first_prem_is_disj t
   360 			    then EVERY[etac Data.disjE i,rotate_tac ~1 i,
   361 				       rotate_tac ~1  (i+1),
   362 				       flat_prems_tac (i+1)]
   363 			    else all_tac) 
   364 			   THEN REPEAT (eresolve_tac [Data.conjE,Data.exE] i)
   365 			   THEN REPEAT (dresolve_tac [Data.notnotD]   i)) i;
   366 	  in if n<0 then no_tac else DETERM (EVERY'
   367 		[rotate_tac n, etac Data.contrapos2,
   368 		 split_tac splits, 
   369 		 rotate_tac ~1, etac Data.contrapos, rotate_tac ~1, 
   370 		 flat_prems_tac] i)
   371 	  end;
   372   in SUBGOAL tac
   373   end;
   374 
   375 
   376 
   377 (** declare split rules **)
   378 
   379 (* addsplits / delsplits *)
   380 
   381 fun split_name name asm = "split " ^ name ^ (if asm then " asm" else "");  (* loop-tactic name used by addloop/delloop *)
   382 
   383 fun ss addsplits splits =  (* rules whose rhs is a negation are installed via split_asm_tac *)
   384   let fun addsplit (ss,split) =
   385         let val (name,asm) = split_thm_info split
   386         in Simplifier.addloop(ss,(split_name name asm,
   387 		       (if asm then split_asm_tac else split_tac) [split])) end
   388   in foldl addsplit (ss,splits) end;
   389 
   390 fun ss delsplits splits =
   391   let fun delsplit(ss,split) =
   392         let val (name,asm) = split_thm_info split
   393         in Simplifier.delloop(ss,split_name name asm)
   394   end in foldl delsplit (ss,splits) end;
   395 
   396 fun Addsplits splits = (Simplifier.simpset_ref() := 
   397 			Simplifier.simpset() addsplits splits);
   398 fun Delsplits splits = (Simplifier.simpset_ref() := 
   399 			Simplifier.simpset() delsplits splits);
   400 
   401 
   402 (* attributes *)
   403 
   404 val splitN = "split";
   405 
   406 val split_add_global = Simplifier.change_global_ss (op addsplits);
   407 val split_del_global = Simplifier.change_global_ss (op delsplits);
   408 val split_add_local = Simplifier.change_local_ss (op addsplits);
   409 val split_del_local = Simplifier.change_local_ss (op delsplits);
   410 
   411 val split_attr =
   412  (Attrib.add_del_args split_add_global split_del_global,
   413   Attrib.add_del_args split_add_local split_del_local);
   414 
   415 
   416 (* methods *)
   417 
   418 val split_modifiers =
   419  [Args.$$$ splitN -- Args.colon >> K ((I, split_add_local): Method.modifier),
   420   Args.$$$ splitN -- Args.$$$ "add" -- Args.colon >> K (I, split_add_local),
   421   Args.$$$ splitN -- Args.$$$ "del" -- Args.colon >> K (I, split_del_local)];
   422 
   423 val split_meth = Method.thms_args (Method.SIMPLE_METHOD' HEADGOAL o split_tac);  (* method: split_tac with the given theorems, via HEADGOAL *)
   424 
   425 
   426 
   427 (** theory setup **)
   428 
   429 val setup =
   430  [Attrib.add_attributes [(splitN, split_attr, "declare splitter rule")],
   431   Method.add_methods [(splitN, split_meth, "apply splitter rule")]];
   432 
   433 end;