src/HOL/Import/shuffler.ML
author wenzelm
Tue Jul 04 19:49:49 2006 +0200 (2006-07-04)
changeset 19998 c8018518e112
parent 18728 6790126ab5f6
child 20071 8f3e1ddb50e6
permissions -rw-r--r--
Thm.varifyT;
     1 (*  Title:      HOL/Import/shuffler.ML
     2     ID:         $Id$
     3     Author:     Sebastian Skalberg, TU Muenchen
     4 
     5 Package for proving two terms equal by normalizing (hence the
     6 "shuffler" name).  Uses the simplifier for the normalization.
     7 *)
     8 
     9 signature Shuffler =
    10 sig
    11     val debug      : bool ref
    12 
    13     val norm_term  : theory -> term -> thm
    14     val make_equal : theory -> term -> term -> thm option
    15     val set_prop   : theory -> term -> (string * thm) list -> (string * thm) option
    16 
    17     val find_potential: theory -> term -> (string * thm) list
    18 
    19     val gen_shuffle_tac: theory -> bool -> (string * thm) list -> int -> tactic
    20 
    21     val shuffle_tac: (string * thm) list -> int -> tactic
    22     val search_tac : (string * thm) list -> int -> tactic
    23 
    24     val print_shuffles: theory -> unit
    25 
    26     val add_shuffle_rule: thm -> theory -> theory
    27     val shuffle_attr: attribute
    28 
    29     val setup      : theory -> theory
    30 end
    31 
    32 structure Shuffler :> Shuffler =
    33 struct
    34 
    35 val debug = ref false
    36 
    37 fun if_debug f x = if !debug then f x else ()
    38 val message = if_debug writeln
    39 
    40 (*Prints exceptions readably to users*)
    41 fun print_sign_exn_unit sign e = 
    42   case e of
    43      THM (msg,i,thms) =>
    44 	 (writeln ("Exception THM " ^ string_of_int i ^ " raised:\n" ^ msg);
    45 	  List.app print_thm thms)
    46    | THEORY (msg,thys) =>
    47 	 (writeln ("Exception THEORY raised:\n" ^ msg);
    48 	  List.app (writeln o Context.str_of_thy) thys)
    49    | TERM (msg,ts) =>
    50 	 (writeln ("Exception TERM raised:\n" ^ msg);
    51 	  List.app (writeln o Sign.string_of_term sign) ts)
    52    | TYPE (msg,Ts,ts) =>
    53 	 (writeln ("Exception TYPE raised:\n" ^ msg);
    54 	  List.app (writeln o Sign.string_of_typ sign) Ts;
    55 	  List.app (writeln o Sign.string_of_term sign) ts)
    56    | e => raise e
    57 
    58 (*Prints an exception, then fails*)
    59 fun print_sign_exn sign e = (print_sign_exn_unit sign e; raise e)
    60 
    61 val string_of_thm = Library.setmp print_mode [] string_of_thm;
    62 val string_of_cterm = Library.setmp print_mode [] string_of_cterm;
    63 
    64 fun mk_meta_eq th =
    65     (case concl_of th of
    66 	 Const("Trueprop",_) $ (Const("op =",_) $ _ $ _) => th RS eq_reflection
    67        | Const("==",_) $ _ $ _ => th
    68        | _ => raise THM("Not an equality",0,[th]))
    69     handle _ => raise THM("Couldn't make meta equality",0,[th])
    70 				   
    71 fun mk_obj_eq th =
    72     (case concl_of th of
    73 	 Const("Trueprop",_) $ (Const("op =",_) $ _ $ _) => th
    74        | Const("==",_) $ _ $ _ => th RS meta_eq_to_obj_eq
    75        | _ => raise THM("Not an equality",0,[th]))
    76     handle _ => raise THM("Couldn't make object equality",0,[th])
    77 
    78 structure ShuffleDataArgs: THEORY_DATA_ARGS =
    79 struct
    80 val name = "HOL/shuffles"
    81 type T = thm list
    82 val empty = []
    83 val copy = I
    84 val extend = I
    85 fun merge _ = Library.gen_union Thm.eq_thm
    86 fun print sg thms =
    87     Pretty.writeln (Pretty.big_list "Shuffle theorems:"
    88 				    (map Display.pretty_thm thms))
    89 end
    90 
    91 structure ShuffleData = TheoryDataFun(ShuffleDataArgs)
    92 
    93 val weaken =
    94     let
    95 	val cert = cterm_of (sign_of ProtoPure.thy)
    96 	val P = Free("P",propT)
    97 	val Q = Free("Q",propT)
    98 	val PQ = Logic.mk_implies(P,Q)
    99 	val PPQ = Logic.mk_implies(P,PQ)
   100 	val cP = cert P
   101 	val cQ = cert Q
   102 	val cPQ = cert PQ
   103 	val cPPQ = cert PPQ
   104 	val th1 = assume cPQ |> implies_intr_list [cPQ,cP]
   105 	val th3 = assume cP
   106 	val th4 = implies_elim_list (assume cPPQ) [th3,th3]
   107 				    |> implies_intr_list [cPPQ,cP]
   108     in
   109 	equal_intr th4 th1 |> standard
   110     end
   111 
   112 val imp_comm =
   113     let
   114 	val cert = cterm_of (sign_of ProtoPure.thy)
   115 	val P = Free("P",propT)
   116 	val Q = Free("Q",propT)
   117 	val R = Free("R",propT)
   118 	val PQR = Logic.mk_implies(P,Logic.mk_implies(Q,R))
   119 	val QPR = Logic.mk_implies(Q,Logic.mk_implies(P,R))
   120 	val cP = cert P
   121 	val cQ = cert Q
   122 	val cPQR = cert PQR
   123 	val cQPR = cert QPR
   124 	val th1 = implies_elim_list (assume cPQR) [assume cP,assume cQ]
   125 				    |> implies_intr_list [cPQR,cQ,cP]
   126 	val th2 = implies_elim_list (assume cQPR) [assume cQ,assume cP]
   127 				    |> implies_intr_list [cQPR,cP,cQ]
   128     in
   129 	equal_intr th1 th2 |> standard
   130     end
   131 
   132 val def_norm =
   133     let
   134 	val cert = cterm_of (sign_of ProtoPure.thy)
   135 	val aT = TFree("'a",[])
   136 	val bT = TFree("'b",[])
   137 	val v = Free("v",aT)
   138 	val P = Free("P",aT-->bT)
   139 	val Q = Free("Q",aT-->bT)
   140 	val cvPQ = cert (list_all ([("v",aT)],Logic.mk_equals(P $ Bound 0,Q $ Bound 0)))
   141 	val cPQ = cert (Logic.mk_equals(P,Q))
   142 	val cv = cert v
   143 	val rew = assume cvPQ
   144 			 |> forall_elim cv
   145 			 |> abstract_rule "v" cv
   146 	val (lhs,rhs) = Logic.dest_equals(concl_of rew)
   147 	val th1 = transitive (transitive
   148 				  (eta_conversion (cert lhs) |> symmetric)
   149 				  rew)
   150 			     (eta_conversion (cert rhs))
   151 			     |> implies_intr cvPQ
   152 	val th2 = combination (assume cPQ) (reflexive cv)
   153 			      |> forall_intr cv
   154 			      |> implies_intr cPQ
   155     in
   156 	equal_intr th1 th2 |> standard
   157     end
   158 
   159 val all_comm =
   160     let
   161 	val cert = cterm_of (sign_of ProtoPure.thy)
   162 	val xT = TFree("'a",[])
   163 	val yT = TFree("'b",[])
   164 	val P = Free("P",xT-->yT-->propT)
   165 	val lhs = all xT $ (Abs("x",xT,all yT $ (Abs("y",yT,P $ Bound 1 $ Bound 0))))
   166 	val rhs = all yT $ (Abs("y",yT,all xT $ (Abs("x",xT,P $ Bound 0 $ Bound 1))))
   167 	val cl = cert lhs
   168 	val cr = cert rhs
   169 	val cx = cert (Free("x",xT))
   170 	val cy = cert (Free("y",yT))
   171 	val th1 = assume cr
   172 			 |> forall_elim_list [cy,cx]
   173 			 |> forall_intr_list [cx,cy]
   174 			 |> implies_intr cr
   175 	val th2 = assume cl
   176 			 |> forall_elim_list [cx,cy]
   177 			 |> forall_intr_list [cy,cx]
   178 			 |> implies_intr cl
   179     in
   180 	equal_intr th1 th2 |> standard
   181     end
   182 
   183 val equiv_comm =
   184     let
   185 	val cert = cterm_of (sign_of ProtoPure.thy)
   186 	val T    = TFree("'a",[])
   187 	val t    = Free("t",T)
   188 	val u    = Free("u",T)
   189 	val ctu  = cert (Logic.mk_equals(t,u))
   190 	val cut  = cert (Logic.mk_equals(u,t))
   191 	val th1  = assume ctu |> symmetric |> implies_intr ctu
   192 	val th2  = assume cut |> symmetric |> implies_intr cut
   193     in
   194 	equal_intr th1 th2 |> standard
   195     end
   196 
   197 (* This simplification procedure rewrites !!x y. P x y
   198 deterministically, so that the normalization function defined
   199 below handles nested quantifiers robustly. *)
   200 
   201 local
   202 
   203 exception RESULT of int
   204 
   205 fun find_bound n (Bound i) = if i = n then raise RESULT 0
   206 			     else if i = n+1 then raise RESULT 1
   207 			     else ()
   208   | find_bound n (t $ u) = (find_bound n t; find_bound n u)
   209   | find_bound n (Abs(_,_,t)) = find_bound (n+1) t
   210   | find_bound _ _ = ()
   211 
   212 fun swap_bound n (Bound i) = if i = n then Bound (n+1)
   213 			     else if i = n+1 then Bound n
   214 			     else Bound i
   215   | swap_bound n (t $ u) = (swap_bound n t $ swap_bound n u)
   216   | swap_bound n (Abs(x,xT,t)) = Abs(x,xT,swap_bound (n+1) t)
   217   | swap_bound n t = t
   218 
   219 fun rew_th sg (xv as (x,xT)) (yv as (y,yT)) t =
   220     let
   221 	val lhs = list_all ([xv,yv],t)
   222 	val rhs = list_all ([yv,xv],swap_bound 0 t)
   223 	val rew = Logic.mk_equals (lhs,rhs)
   224 	val init = trivial (cterm_of sg rew)
   225     in
   226 	(all_comm RS init handle e => (message "rew_th"; OldGoals.print_exn e))
   227     end
   228 
   229 fun quant_rewrite sg assumes (t as Const("all",T1) $ (Abs(x,xT,Const("all",T2) $ Abs(y,yT,body)))) =
   230     let
   231 	val res = (find_bound 0 body;2) handle RESULT i => i
   232     in
   233 	case res of
   234 	    0 => SOME (rew_th sg (x,xT) (y,yT) body)
   235 	  | 1 => if string_ord(y,x) = LESS
   236 		 then
   237 		     let
   238 			 val newt = Const("all",T1) $ (Abs(y,xT,Const("all",T2) $ Abs(x,yT,body)))
   239 			 val t_th    = reflexive (cterm_of sg t)
   240 			 val newt_th = reflexive (cterm_of sg newt)
   241 		     in
   242 			 SOME (transitive t_th newt_th)
   243 		     end
   244 		 else NONE
   245 	  | _ => error "norm_term (quant_rewrite) internal error"
   246      end
   247   | quant_rewrite _ _ _ = (warning "quant_rewrite: Unknown lhs"; NONE)
   248 
   249 fun freeze_thaw_term t =
   250     let
   251 	val tvars = term_tvars t
   252 	val tfree_names = add_term_tfree_names(t,[])
   253 	val (type_inst,_) =
   254 	    Library.foldl (fn ((inst,used),(w as (v,_),S)) =>
   255 		      let
   256 			  val v' = variant used v
   257 		      in
   258 			  ((w,TFree(v',S))::inst,v'::used)
   259 		      end)
   260 		  (([],tfree_names),tvars)
   261 	val t' = subst_TVars type_inst t
   262     in
   263 	(t',map (fn (w,TFree(v,S)) => (v,TVar(w,S))
   264 		  | _ => error "Internal error in Shuffler.freeze_thaw") type_inst)
   265     end
   266 
   267 fun inst_tfrees sg [] thm = thm
   268   | inst_tfrees sg ((name,U)::rest) thm = 
   269     let
   270 	val cU = ctyp_of sg U
   271 	val tfrees = add_term_tfrees (prop_of thm,[])
   272 	val (rens, thm') = Thm.varifyT'
   273     (gen_rem (op = o apfst fst) (tfrees, name)) thm
   274 	val mid = 
   275 	    case rens of
   276 		[] => thm'
   277 	      | [((_, S), idx)] => instantiate
   278             ([(ctyp_of sg (TVar (idx, S)), cU)], []) thm'
   279 	      | _ => error "Shuffler.inst_tfrees internal error"
   280     in
   281 	inst_tfrees sg rest mid
   282     end
   283 
   284 fun is_Abs (Abs _) = true
   285   | is_Abs _ = false
   286 
   287 fun eta_redex (t $ Bound 0) =
   288     let
   289 	fun free n (Bound i) = i = n
   290 	  | free n (t $ u) = free n t orelse free n u
   291 	  | free n (Abs(_,_,t)) = free (n+1) t
   292 	  | free n _ = false
   293     in
   294 	not (free 0 t)
   295     end
   296   | eta_redex _ = false
   297 
   298 fun eta_contract sg assumes origt =
   299     let
   300 	val (typet,Tinst) = freeze_thaw_term origt
   301 	val (init,thaw) = freeze_thaw (reflexive (cterm_of sg typet))
   302 	val final = inst_tfrees sg Tinst o thaw
   303 	val t = #1 (Logic.dest_equals (prop_of init))
   304 	val _ =
   305 	    let
   306 		val lhs = #1 (Logic.dest_equals (prop_of (final init)))
   307 	    in
   308 		if not (lhs aconv origt)
   309 		then (writeln "Something is utterly wrong: (orig,lhs,frozen type,t,tinst)";
   310 		      writeln (string_of_cterm (cterm_of sg origt));
   311 		      writeln (string_of_cterm (cterm_of sg lhs));
   312 		      writeln (string_of_cterm (cterm_of sg typet));
   313 		      writeln (string_of_cterm (cterm_of sg t));
   314 		      app (fn (n,T) => writeln (n ^ ": " ^ (string_of_ctyp (ctyp_of sg T)))) Tinst;
   315 		      writeln "done")
   316 		else ()
   317 	    end
   318     in
   319 	case t of
   320 	    Const("all",_) $ (Abs(x,xT,Const("==",eqT) $ P $ Q)) =>
   321 	    ((if eta_redex P andalso eta_redex Q
   322 	      then
   323 		  let
   324 		      val cert = cterm_of sg
   325 		      val v = Free(variant (add_term_free_names(t,[])) "v",xT)
   326 		      val cv = cert v
   327 		      val ct = cert t
   328 		      val th = (assume ct)
   329 				   |> forall_elim cv
   330 				   |> abstract_rule x cv
   331 		      val ext_th = eta_conversion (cert (Abs(x,xT,P)))
   332 		      val th' = transitive (symmetric ext_th) th
   333 		      val cu = cert (prop_of th')
   334 		      val uth = combination (assume cu) (reflexive cv)
   335 		      val uth' = (beta_conversion false (cert (Abs(x,xT,Q) $ v)))
   336 				     |> transitive uth
   337 				     |> forall_intr cv
   338 				     |> implies_intr cu
   339 		      val rew_th = equal_intr (th' |> implies_intr ct) uth'
   340 		      val res = final rew_th
   341 		      val lhs = (#1 (Logic.dest_equals (prop_of res)))
   342 		  in
   343 		       SOME res
   344 		  end
   345 	      else NONE)
   346 	     handle e => OldGoals.print_exn e)
   347 	  | _ => NONE
   348        end
   349 
   350 fun beta_fun sg assume t =
   351     SOME (beta_conversion true (cterm_of sg t))
   352 
   353 val meta_sym_rew = thm "refl"
   354 
   355 fun equals_fun sg assume t =
   356     case t of
   357 	Const("op ==",_) $ u $ v => if Term.term_ord (u,v) = LESS then SOME (meta_sym_rew) else NONE
   358       | _ => NONE
   359 
   360 fun eta_expand sg assumes origt =
   361     let
   362 	val (typet,Tinst) = freeze_thaw_term origt
   363 	val (init,thaw) = freeze_thaw (reflexive (cterm_of sg typet))
   364 	val final = inst_tfrees sg Tinst o thaw
   365 	val t = #1 (Logic.dest_equals (prop_of init))
   366 	val _ =
   367 	    let
   368 		val lhs = #1 (Logic.dest_equals (prop_of (final init)))
   369 	    in
   370 		if not (lhs aconv origt)
   371 		then (writeln "Something is utterly wrong: (orig,lhs,frozen type,t,tinst)";
   372 		      writeln (string_of_cterm (cterm_of sg origt));
   373 		      writeln (string_of_cterm (cterm_of sg lhs));
   374 		      writeln (string_of_cterm (cterm_of sg typet));
   375 		      writeln (string_of_cterm (cterm_of sg t));
   376 		      app (fn (n,T) => writeln (n ^ ": " ^ (string_of_ctyp (ctyp_of sg T)))) Tinst;
   377 		      writeln "done")
   378 		else ()
   379 	    end
   380     in
   381 	case t of
   382 	    Const("==",T) $ P $ Q =>
   383 	    if is_Abs P orelse is_Abs Q
   384 	    then (case domain_type T of
   385 		      Type("fun",[aT,bT]) =>
   386 		      let
   387 			  val cert = cterm_of sg
   388 			  val vname = variant (add_term_free_names(t,[])) "v"
   389 			  val v = Free(vname,aT)
   390 			  val cv = cert v
   391 			  val ct = cert t
   392 			  val th1 = (combination (assume ct) (reflexive cv))
   393 					|> forall_intr cv
   394 					|> implies_intr ct
   395 			  val concl = cert (concl_of th1)
   396 			  val th2 = (assume concl)
   397 					|> forall_elim cv
   398 					|> abstract_rule vname cv
   399 			  val (lhs,rhs) = Logic.dest_equals (prop_of th2)
   400 			  val elhs = eta_conversion (cert lhs)
   401 			  val erhs = eta_conversion (cert rhs)
   402 			  val th2' = transitive
   403 					 (transitive (symmetric elhs) th2)
   404 					 erhs
   405 			  val res = equal_intr th1 (th2' |> implies_intr concl)
   406 			  val res' = final res
   407 		      in
   408 			  SOME res'
   409 		      end
   410 		    | _ => NONE)
   411 	    else NONE
   412 	  | _ => (error ("Bad eta_expand argument" ^ (string_of_cterm (cterm_of sg t))); NONE)
   413     end
   414     handle e => (writeln "eta_expand internal error"; OldGoals.print_exn e)
   415 
   416 fun mk_tfree s = TFree("'"^s,[])
   417 val xT = mk_tfree "a"
   418 val yT = mk_tfree "b"
   419 val P  = Var(("P",0),xT-->yT-->propT)
   420 val Q  = Var(("Q",0),xT-->yT)
   421 val R  = Var(("R",0),xT-->yT)
   422 val S  = Var(("S",0),xT)
   423 val S'  = Var(("S'",0),xT)
   424 in
   425 fun beta_simproc sg = Simplifier.simproc_i
   426 		      sg
   427 		      "Beta-contraction"
   428 		      [Abs("x",xT,Q) $ S]
   429 		      beta_fun
   430 
   431 fun equals_simproc sg = Simplifier.simproc_i
   432 		      sg
   433 		      "Ordered rewriting of meta equalities"
   434 		      [Const("op ==",xT) $ S $ S']
   435 		      equals_fun
   436 
   437 fun quant_simproc sg = Simplifier.simproc_i
   438 			   sg
   439 			   "Ordered rewriting of nested quantifiers"
   440 			   [all xT $ (Abs("x",xT,all yT $ (Abs("y",yT,P $ Bound 1 $ Bound 0))))]
   441 			   quant_rewrite
   442 fun eta_expand_simproc sg = Simplifier.simproc_i
   443 			 sg
   444 			 "Smart eta-expansion by equivalences"
   445 			 [Logic.mk_equals(Q,R)]
   446 			 eta_expand
   447 fun eta_contract_simproc sg = Simplifier.simproc_i
   448 			 sg
   449 			 "Smart handling of eta-contractions"
   450 			 [all xT $ (Abs("x",xT,Logic.mk_equals(Q $ Bound 0,R $ Bound 0)))]
   451 			 eta_contract
   452 end
   453 
   454 (* Disambiguates the names of bound variables in a term, returning t
   455 == t' where all the names of bound variables in t' are unique *)
   456 
   457 fun disamb_bound sg t =
   458     let
   459 	
   460 	fun F (t $ u,idx) =
   461 	    let
   462 		val (t',idx') = F (t,idx)
   463 		val (u',idx'') = F (u,idx')
   464 	    in
   465 		(t' $ u',idx'')
   466 	    end
   467 	  | F (Abs(x,xT,t),idx) =
   468 	    let
   469 		val x' = "x" ^ (LargeInt.toString idx) (* amazing *)
   470 		val (t',idx') = F (t,idx+1)
   471 	    in
   472 		(Abs(x',xT,t'),idx')
   473 	    end
   474 	  | F arg = arg
   475 	val (t',_) = F (t,0)
   476 	val ct = cterm_of sg t
   477 	val ct' = cterm_of sg t'
   478 	val res = transitive (reflexive ct) (reflexive ct')
   479 	val _ = message ("disamb_term: " ^ (string_of_thm res))
   480     in
   481 	res
   482     end
   483 
   484 (* Transforms a term t to some normal form t', returning the theorem t
   485 == t'.  This is originally a help function for make_equal, but might
   486 be handy in its own right, for example for indexing terms. *)
   487 
   488 fun norm_term thy t =
   489     let
   490 	val sg = sign_of thy
   491 
   492 	val norms = ShuffleData.get thy
   493 	val ss = Simplifier.theory_context thy empty_ss
   494           setmksimps single
   495 	  addsimps (map (Thm.transfer sg) norms)
   496 	fun chain f th =
   497 	    let
   498                 val rhs = snd (dest_equals (cprop_of th))
   499       	    in
   500 		transitive th (f rhs)
   501 	    end
   502 
   503 	val th =
   504 	    t |> disamb_bound sg
   505 	      |> chain (Simplifier.full_rewrite
   506 			    (ss addsimprocs [quant_simproc sg,eta_expand_simproc sg,eta_contract_simproc sg]))
   507 	      |> chain eta_conversion
   508 	      |> strip_shyps
   509 	val _ = message ("norm_term: " ^ (string_of_thm th))
   510     in
   511 	th
   512     end
   513     handle e => (writeln "norm_term internal error"; print_sign_exn (sign_of thy) e)
   514 
   515 
   516 (* Closes a theorem with respect to free and schematic variables (does
   517 not touch type variables, though). *)
   518 
   519 fun close_thm th =
   520     let
   521 	val sg = sign_of_thm th
   522 	val c = prop_of th
   523 	val vars = add_term_frees (c,add_term_vars(c,[]))
   524     in
   525 	Drule.forall_intr_list (map (cterm_of sg) vars) th
   526     end
   527     handle e => (writeln "close_thm internal error"; OldGoals.print_exn e)
   528 
   529 (* Normalizes a theorem's conclusion using norm_term. *)
   530 
   531 fun norm_thm thy th =
   532     let
   533 	val c = prop_of th
   534     in
   535 	equal_elim (norm_term thy c) th
   536     end
   537 
   538 (* make_equal sg t u tries to construct the theorem t == u under the
   539 signature sg.  If it succeeds, SOME (t == u) is returned, otherwise
   540 NONE is returned. *)
   541 
   542 fun make_equal sg t u =
   543     let
   544 	val t_is_t' = norm_term sg t
   545 	val u_is_u' = norm_term sg u
   546 	val th = transitive t_is_t' (symmetric u_is_u')
   547 	val _ = message ("make_equal: SOME " ^ (string_of_thm th))
   548     in
   549 	SOME th
   550     end
   551     handle e as THM _ => (message "make_equal: NONE";NONE)
   552 			 
   553 fun match_consts ignore t (* th *) =
   554     let
   555 	fun add_consts (Const (c, _), cs) =
   556 	    if c mem_string ignore
   557 	    then cs
   558 	    else c ins_string cs
   559 	  | add_consts (t $ u, cs) = add_consts (t, add_consts (u, cs))
   560 	  | add_consts (Abs (_, _, t), cs) = add_consts (t, cs)
   561 	  | add_consts (_, cs) = cs
   562 	val t_consts = add_consts(t,[])
   563     in
   564      fn (name,th) =>
   565 	let
   566 	    val th_consts = add_consts(prop_of th,[])
   567 	in
   568 	    eq_set(t_consts,th_consts)
   569 	end
   570     end
   571     
   572 val collect_ignored =
   573     foldr (fn (thm,cs) =>
   574 	      let
   575 		  val (lhs,rhs) = Logic.dest_equals (prop_of thm)
   576 		  val ignore_lhs = term_consts lhs \\ term_consts rhs
   577 		  val ignore_rhs = term_consts rhs \\ term_consts lhs
   578 	      in
   579 		  foldr (op ins_string) cs (ignore_lhs @ ignore_rhs)
   580 	      end)
   581 
   582 (* set_prop t thms tries to make a theorem with the proposition t from
   583 one of the theorems thms, by shuffling the propositions around.  If it
   584 succeeds, SOME theorem is returned, otherwise NONE.  *)
   585 
   586 fun set_prop thy t =
   587     let
   588 	val sg = sign_of thy
   589 	val vars = add_term_frees (t,add_term_vars (t,[]))
   590 	val closed_t = foldr (fn (v,body) => let val vT = type_of v
   591 					     in all vT $ (Abs("x",vT,abstract_over(v,body))) end) t vars
   592 	val rew_th = norm_term thy closed_t
   593 	val rhs = snd (dest_equals (cprop_of rew_th))
   594 
   595 	val shuffles = ShuffleData.get thy
   596 	fun process [] = NONE
   597 	  | process ((name,th)::thms) =
   598 	    let
   599 		val norm_th = Thm.varifyT (norm_thm thy (close_thm (Thm.transfer sg th)))
   600 		val triv_th = trivial rhs
   601 		val _ = message ("Shuffler.set_prop: Gluing together " ^ (string_of_thm norm_th) ^ " and " ^ (string_of_thm triv_th))
   602 		val mod_th = case Seq.pull (bicompose false (*true*) (false,norm_th,0) 1 triv_th) of
   603 				 SOME(th,_) => SOME th
   604 			       | NONE => NONE
   605 	    in
   606 		case mod_th of
   607 		    SOME mod_th =>
   608 		    let
   609 			val closed_th = equal_elim (symmetric rew_th) mod_th
   610 		    in
   611 			message ("Shuffler.set_prop succeeded by " ^ name);
   612 			SOME (name,forall_elim_list (map (cterm_of sg) vars) closed_th)
   613 		    end
   614 		  | NONE => process thms
   615 	    end
   616 	    handle e as THM _ => process thms
   617     in
   618 	fn thms =>
   619 	   case process thms of
   620 	       res as SOME (name,th) => if (prop_of th) aconv t
   621 					then res
   622 					else error "Internal error in set_prop"
   623 	     | NONE => NONE
   624     end
   625     handle e => (writeln "set_prop internal error"; OldGoals.print_exn e)
   626 
   627 fun find_potential thy t =
   628     let
   629 	val shuffles = ShuffleData.get thy
   630 	val ignored = collect_ignored [] shuffles
   631 	val rel_consts = term_consts t \\ ignored
   632 	val pot_thms = PureThy.thms_containing_consts thy rel_consts
   633     in
   634 	List.filter (match_consts ignored t) pot_thms
   635     end
   636 
   637 fun gen_shuffle_tac thy search thms i st =
   638     let
   639 	val _ = message ("Shuffling " ^ (string_of_thm st))
   640 	val t = List.nth(prems_of st,i-1)
   641 	val set = set_prop thy t
   642 	fun process_tac thms st =
   643 	    case set thms of
   644 		SOME (_,th) => Seq.of_list (compose (th,i,st))
   645 	      | NONE => Seq.empty
   646     in
   647 	(process_tac thms APPEND (if search
   648 				  then process_tac (find_potential thy t)
   649 				  else no_tac)) st
   650     end
   651 
   652 fun shuffle_tac thms i st =
   653     gen_shuffle_tac (the_context()) false thms i st
   654 
   655 fun search_tac thms i st =
   656     gen_shuffle_tac (the_context()) true thms i st
   657 
   658 fun shuffle_meth (thms:thm list) ctxt =
   659     let
   660 	val thy = ProofContext.theory_of ctxt
   661     in
   662 	Method.SIMPLE_METHOD' HEADGOAL (gen_shuffle_tac thy false (map (pair "") thms))
   663     end
   664 
   665 fun search_meth ctxt =
   666     let
   667 	val thy = ProofContext.theory_of ctxt
   668 	val prems = ProofContext.prems_of ctxt
   669     in
   670 	Method.SIMPLE_METHOD' HEADGOAL (gen_shuffle_tac thy true (map (pair "premise") prems))
   671     end
   672 
   673 val print_shuffles = ShuffleData.print
   674 
   675 fun add_shuffle_rule thm thy =
   676     let
   677 	val shuffles = ShuffleData.get thy
   678     in
   679 	if exists (curry Thm.eq_thm thm) shuffles
   680 	then (warning ((string_of_thm thm) ^ " already known to the shuffler");
   681 	      thy)
   682 	else ShuffleData.put (thm::shuffles) thy
   683     end
   684 
   685 val shuffle_attr = Thm.declaration_attribute (Context.map_theory o add_shuffle_rule);
   686 
   687 val setup =
   688   Method.add_method ("shuffle_tac",Method.thms_ctxt_args shuffle_meth,"solve goal by shuffling terms around") #>
   689   Method.add_method ("search_tac",Method.ctxt_args search_meth,"search for suitable theorems") #>
   690   ShuffleData.init #>
   691   add_shuffle_rule weaken #>
   692   add_shuffle_rule equiv_comm #>
   693   add_shuffle_rule imp_comm #>
   694   add_shuffle_rule Drule.norm_hhf_eq #>
   695   add_shuffle_rule Drule.triv_forall_equality #>
   696   Attrib.add_attributes [("shuffle_rule", Attrib.no_args shuffle_attr, "declare rule for shuffler")]
   697 
   698 end