src/HOL/Import/shuffler.ML
author wenzelm
Sat Mar 15 18:07:58 2008 +0100 (2008-03-15)
changeset 26277 461e11226111
parent 24634 38db11874724
child 26424 a6cad32a27b0
permissions -rw-r--r--
removed obsolete PureThy.thms_containing_consts;
skalberg@14620
     1
(*  Title:      HOL/Import/shuffler.ML
skalberg@14516
     2
    ID:         $Id$
skalberg@14516
     3
    Author:     Sebastian Skalberg, TU Muenchen
skalberg@14516
     4
skalberg@14516
     5
Package for proving two terms equal by normalizing (hence the
skalberg@14516
     6
"shuffler" name).  Uses the simplifier for the normalization.
skalberg@14516
     7
*)
skalberg@14516
     8
skalberg@14516
     9
(* Public interface of the shuffler package. *)
signature Shuffler =
sig
    (* When set, internal tracing messages are printed. *)
    val debug            : bool ref

    (* norm_term thy t proves t == t' for a normal form t' of t. *)
    val norm_term        : theory -> term -> thm
    (* make_equal thy t u returns SOME (t == u), or NONE on failure. *)
    val make_equal       : theory -> term -> term -> thm option
    val set_prop         : theory -> term -> (string * thm) list -> (string * thm) option

    val find_potential   : theory -> term -> (string * thm) list

    val gen_shuffle_tac  : theory -> bool -> (string * thm) list -> int -> tactic
    val shuffle_tac      : (string * thm) list -> int -> tactic
    val search_tac       : (string * thm) list -> int -> tactic

    (* Prints the theorems registered as shuffle rules. *)
    val print_shuffles   : theory -> unit

    val add_shuffle_rule : thm -> theory -> theory
    val shuffle_attr     : attribute

    val setup            : theory -> theory
end
skalberg@14516
    31
skalberg@14516
    32
structure Shuffler :> Shuffler =
skalberg@14516
    33
struct
skalberg@14516
    34
skalberg@14516
    35
(* Global switch for the tracing output produced through message. *)
val debug = ref false

(* Applies f to x only when debugging is switched on; otherwise does nothing. *)
fun if_debug f x =
    if not (!debug) then () else f x

(* writeln that only prints while debugging is switched on. *)
val message = if_debug writeln
skalberg@14516
    39
skalberg@14516
    40
(* Prints the payload of the kernel exceptions THM, THEORY, TERM and TYPE
   in readable form, rendering terms and types with sign.  Any other
   exception is re-raised unchanged. *)
fun print_sign_exn_unit sign e =
  let
    fun show_term t = writeln (Sign.string_of_term sign t)
    fun show_typ T = writeln (Sign.string_of_typ sign T)
  in
    case e of
      THM (msg, i, thms) =>
        (writeln ("Exception THM " ^ string_of_int i ^ " raised:\n" ^ msg);
         List.app print_thm thms)
    | THEORY (msg, thys) =>
        (writeln ("Exception THEORY raised:\n" ^ msg);
         List.app (writeln o Context.str_of_thy) thys)
    | TERM (msg, ts) =>
        (writeln ("Exception TERM raised:\n" ^ msg);
         List.app show_term ts)
    | TYPE (msg, Ts, ts) =>
        (writeln ("Exception TYPE raised:\n" ^ msg);
         List.app show_typ Ts;
         List.app show_term ts)
    | _ => raise e
  end
skalberg@14516
    57
skalberg@14516
    58
(* Reports e through print_sign_exn_unit and then raises it again. *)
fun print_sign_exn sign e =
  let val () = print_sign_exn_unit sign e in raise e end
skalberg@14516
    60
wenzelm@24634
    61
(* Local printers that run the global ones with the print mode
   temporarily set to []. *)
val string_of_thm   = PrintMode.setmp [] string_of_thm
val string_of_cterm = PrintMode.setmp [] string_of_cterm
skalberg@14516
    63
skalberg@14516
    64
(* Turns an equality theorem into a meta equality (==).
   - An object equality Trueprop (a = b) is converted via eq_reflection.
   - A meta equality is returned unchanged.
   Raises THM "Not an equality" for any other conclusion, and
   THM "Couldn't make meta equality" if resolution with eq_reflection
   fails.  The handler covers only the RS step, for two reasons: the
   "Not an equality" diagnosis is not overwritten, and non-THM exceptions
   propagate instead of being converted. *)
fun mk_meta_eq th =
    (case concl_of th of
         Const ("Trueprop", _) $ (Const ("op =", _) $ _ $ _) =>
             (th RS eq_reflection
              handle THM _ => raise THM ("Couldn't make meta equality", 0, [th]))
       | Const ("==", _) $ _ $ _ => th
       | _ => raise THM ("Not an equality", 0, [th]))
wenzelm@21588
    70
skalberg@14516
    71
(* Turns an equality theorem into an object equality (=).
   - An object equality Trueprop (a = b) is returned unchanged.
   - A meta equality is converted via meta_eq_to_obj_eq.
   Raises THM "Not an equality" for any other conclusion, and
   THM "Couldn't make object equality" if resolution with
   meta_eq_to_obj_eq fails.  The handler covers only the RS step, for two
   reasons: the "Not an equality" diagnosis is not overwritten, and
   non-THM exceptions propagate instead of being converted. *)
fun mk_obj_eq th =
    (case concl_of th of
         Const ("Trueprop", _) $ (Const ("op =", _) $ _ $ _) => th
       | Const ("==", _) $ _ $ _ =>
             (th RS meta_eq_to_obj_eq
              handle THM _ => raise THM ("Couldn't make object equality", 0, [th]))
       | _ => raise THM ("Not an equality", 0, [th]))
skalberg@14516
    77
wenzelm@22846
    78
(* Theory data: the theorems registered as shuffle rules.  Merging two
   theories takes the union of their rule lists modulo Thm.eq_thm. *)
structure ShuffleData = TheoryDataFun
(
  type T = thm list
  val empty = []
  val copy = I
  val extend = I
  fun merge _ rules = Library.gen_union Thm.eq_thm rules
)
skalberg@14516
    86
wenzelm@22846
    87
(* Pretty-prints every shuffle rule stored in thy under one heading. *)
fun print_shuffles thy =
  ShuffleData.get thy
  |> map Display.pretty_thm
  |> Pretty.big_list "Shuffle theorems:"
  |> Pretty.writeln
skalberg@14516
    90
skalberg@14516
    91
(* Proves (P ==> P ==> Q) == (P ==> Q): a duplicated premise can be
   dropped.  P and Q are generalized by standard.  The unused local
   certifying Q on its own has been removed. *)
val weaken =
    let
        val cert = cterm_of ProtoPure.thy
        val P = Free("P",propT)
        val Q = Free("Q",propT)
        val PQ = Logic.mk_implies(P,Q)
        val PPQ = Logic.mk_implies(P,PQ)
        val cP = cert P
        val cPQ = cert PQ
        val cPPQ = cert PPQ
        (* (P ==> Q) ==> (P ==> P ==> Q): adds a vacuous premise P *)
        val th1 = assume cPQ |> implies_intr_list [cPQ,cP]
        val th3 = assume cP
        (* (P ==> P ==> Q) ==> (P ==> Q): uses the premise P twice *)
        val th4 = implies_elim_list (assume cPPQ) [th3,th3]
                                    |> implies_intr_list [cPPQ,cP]
    in
        equal_intr th4 th1 |> standard
    end
skalberg@14516
   109
skalberg@14516
   110
(* Proves (P ==> Q ==> R) == (Q ==> P ==> R): two premises commute.
   The free variables are generalized by standard. *)
val imp_comm =
    let
        val cert = cterm_of ProtoPure.thy
        fun prop_var name = Free (name, propT)
        val P = prop_var "P"
        val Q = prop_var "Q"
        val R = prop_var "R"
        val cP = cert P
        val cQ = cert Q
        val cPQR = cert (Logic.mk_implies (P, Logic.mk_implies (Q, R)))
        val cQPR = cert (Logic.mk_implies (Q, Logic.mk_implies (P, R)))
        (* (P ==> Q ==> R) ==> (Q ==> P ==> R) *)
        val forward =
            implies_elim_list (assume cPQR) [assume cP, assume cQ]
            |> implies_intr_list [cPQR, cQ, cP]
        (* (Q ==> P ==> R) ==> (P ==> Q ==> R) *)
        val backward =
            implies_elim_list (assume cQPR) [assume cQ, assume cP]
            |> implies_intr_list [cQPR, cP, cQ]
    in
        standard (equal_intr forward backward)
    end
skalberg@14516
   129
skalberg@14516
   130
(* Proves (!!v. P v == Q v) == (P == Q) for P, Q : 'a => 'b.  In other
   words, a pointwise meta equation is equivalent to the equation between
   the functions themselves.  Free variables are generalized by standard. *)
val def_norm =
    let
        val cert = cterm_of ProtoPure.thy
        val aT = TFree ("'a", [])
        val bT = TFree ("'b", [])
        val fT = aT --> bT
        val v = Free ("v", aT)
        val P = Free ("P", fT)
        val Q = Free ("Q", fT)
        val c_pointwise =
            cert (list_all ([("v", aT)], Logic.mk_equals (P $ Bound 0, Q $ Bound 0)))
        val c_funeq = cert (Logic.mk_equals (P, Q))
        val cv = cert v
        (* (%v. P v) == (%v. Q v), under the pointwise hypothesis *)
        val abs_eq = abstract_rule "v" cv (forall_elim cv (assume c_pointwise))
        val (lhs, rhs) = Logic.dest_equals (concl_of abs_eq)
        (* eta-contract both sides, giving P == Q *)
        val lhs_eta = symmetric (eta_conversion (cert lhs))
        val rhs_eta = eta_conversion (cert rhs)
        val to_funeq =
            implies_intr c_pointwise (transitive (transitive lhs_eta abs_eq) rhs_eta)
        (* apply P == Q to v and generalize over v *)
        val to_pointwise =
            implies_intr c_funeq (forall_intr cv (combination (assume c_funeq) (reflexive cv)))
    in
        standard (equal_intr to_funeq to_pointwise)
    end
skalberg@14516
   156
skalberg@14516
   157
(* Proves (!!y x. P x y) == (!!x y. P x y): two adjacent meta quantifiers
   over types 'a and 'b commute.  Free variables are generalized by
   standard. *)
val all_comm =
    let
        val cert = cterm_of ProtoPure.thy
        val xT = TFree ("'a", [])
        val yT = TFree ("'b", [])
        val P = Free ("P", xT --> yT --> propT)
        val xy_form = all xT $ Abs ("x", xT, all yT $ Abs ("y", yT, P $ Bound 1 $ Bound 0))
        val yx_form = all yT $ Abs ("y", yT, all xT $ Abs ("x", xT, P $ Bound 0 $ Bound 1))
        val c_xy = cert xy_form
        val c_yx = cert yx_form
        val cx = cert (Free ("x", xT))
        val cy = cert (Free ("y", yT))
        (* (!!y x. P x y) ==> (!!x y. P x y) *)
        val yx_to_xy =
            implies_intr c_yx (forall_intr_list [cx, cy] (forall_elim_list [cy, cx] (assume c_yx)))
        (* (!!x y. P x y) ==> (!!y x. P x y) *)
        val xy_to_yx =
            implies_intr c_xy (forall_intr_list [cy, cx] (forall_elim_list [cx, cy] (assume c_xy)))
    in
        standard (equal_intr yx_to_xy xy_to_yx)
    end
skalberg@14516
   180
skalberg@14516
   181
(* Proves (t == u) == (u == t) for t, u : 'a, generalized by standard. *)
val equiv_comm =
    let
        val cert = cterm_of ProtoPure.thy
        val T = TFree ("'a", [])
        val t = Free ("t", T)
        val u = Free ("u", T)
        val ctu = cert (Logic.mk_equals (t, u))
        val cut = cert (Logic.mk_equals (u, t))
        (* ceq ==> flipped ceq *)
        fun flip ceq = implies_intr ceq (symmetric (assume ceq))
    in
        standard (equal_intr (flip ctu) (flip cut))
    end
skalberg@14516
   194
skalberg@14516
   195
(* This simplification procedure rewrites !!x y. P x y
deterministically, so that the normalization function defined
below handles nested quantifiers robustly. *)
skalberg@14516
   198
skalberg@14516
   199
local
skalberg@14516
   200
skalberg@14516
   201
(* Raised by find_bound to report which of two bound variables occurs first. *)
exception RESULT of int
skalberg@14516
   202
skalberg@14516
   203
(* Walks the term left to right (function before argument).  At the first
   occurrence of Bound n it raises RESULT 0; at the first occurrence of
   Bound (n+1) it raises RESULT 1.  Both indices are shifted under
   binders.  Returns () if neither occurs. *)
fun find_bound n t =
    case t of
        Bound i =>
            if i = n then raise RESULT 0
            else if i = n + 1 then raise RESULT 1
            else ()
      | f $ x => (find_bound n f; find_bound n x)
      | Abs (_, _, body) => find_bound (n + 1) body
      | _ => ()
skalberg@14516
   209
skalberg@14516
   210
(* Exchanges the de Bruijn indices n and n+1 (shifted under binders) and
   leaves every other subterm as it is. *)
fun swap_bound n t =
    case t of
        Bound i =>
            if i = n then Bound (n + 1)
            else if i = n + 1 then Bound n
            else t
      | f $ x => swap_bound n f $ swap_bound n x
      | Abs (name, T, body) => Abs (name, T, swap_bound (n + 1) body)
      | _ => t
skalberg@14516
   216
haftmann@21078
   217
(* Builds (!!x y. t) == (!!y x. t') by resolving all_comm against the
   trivial goal for that equation.  Here t' is t with the two bound
   variables swapped.  On failure it emits the debug message "rew_th" and
   hands the exception to OldGoals.print_exn. *)
fun rew_th thy xv yv t =
    let
        val goal =
            Logic.mk_equals (list_all ([xv, yv], t), list_all ([yv, xv], swap_bound 0 t))
        val init = trivial (cterm_of thy goal)
    in
        all_comm RS init
        handle e => (message "rew_th"; OldGoals.print_exn e)
    end
skalberg@14516
   226
haftmann@21078
   227
(* Simproc body for nested quantifiers !!x y. body.
   - If y (Bound 0) occurs before x in body, the quantifiers are swapped
     via rew_th.
   - If x occurs first, the binders are only renamed (the new term is
     alpha-equivalent), and only when y sorts before x; otherwise NONE.
   - If neither occurs, an internal error is raised.
   Any other term yields a warning and NONE. *)
fun quant_rewrite thy _ (t as Const("all",T1) $ (Abs(x,xT,Const("all",T2) $ Abs(y,yT,body)))) =
    (case ((find_bound 0 body; 2) handle RESULT i => i) of
         0 => SOME (rew_th thy (x,xT) (y,yT) body)
       | 1 =>
           if string_ord (y, x) <> LESS then NONE
           else
               let
                   (* same term up to the names of the two binders *)
                   val renamed = Const("all",T1) $ (Abs(y,xT,Const("all",T2) $ Abs(x,yT,body)))
               in
                   SOME (transitive (reflexive (cterm_of thy t)) (reflexive (cterm_of thy renamed)))
               end
       | _ => error "norm_term (quant_rewrite) internal error")
  | quant_rewrite _ _ _ = (warning "quant_rewrite: Unknown lhs"; NONE)
skalberg@14516
   246
skalberg@14516
   247
(* Replaces each schematic type variable of t by a free type variable
   whose name is fresh with respect to the TFree names already in t.
   Returns the frozen term together with a list that maps each new TFree
   name back to its original TVar (consumed by inst_tfrees). *)
fun freeze_thaw_term t =
    let
        fun freeze (w as (v, _), S) (inst, used) =
            let
                val v' = Name.variant used v
            in
                ((w, TFree (v', S)) :: inst, v' :: used)
            end
        val (type_inst, _) = fold freeze (term_tvars t) ([], add_term_tfree_names (t, []))
        fun thaw_entry (w, TFree (v, S)) = (v, TVar (w, S))
          | thaw_entry _ = error "Internal error in Shuffler.freeze_thaw"
    in
        (subst_TVars type_inst t, map thaw_entry type_inst)
    end
skalberg@14516
   264
haftmann@21078
   265
(* For each (name, U), processed in order, instantiates thm at the free
   type variable called name.  It passes all other TFrees of thm to
   Thm.varifyT' (presumably as the ones to keep fixed — confirm), then
   instantiates the single resulting TVar with U.  If nothing was renamed,
   thm passes through unchanged. *)
fun inst_tfrees _ [] thm = thm
  | inst_tfrees thy ((name, U) :: rest) thm =
    let
        val cU = ctyp_of thy U
        val others = remove (op = o apsnd fst) name (add_term_tfrees (prop_of thm, []))
        val (rens, thm') = Thm.varifyT' others thm
        val mid =
            (case rens of
                 [] => thm'
               | [((_, S), idx)] => instantiate ([(ctyp_of thy (TVar (idx, S)), cU)], []) thm'
               | _ => error "Shuffler.inst_tfrees internal error")
    in
        inst_tfrees thy rest mid
    end
skalberg@14516
   281
skalberg@14516
   282
(* True exactly for lambda abstractions. *)
fun is_Abs t = (case t of Abs _ => true | _ => false)
skalberg@14516
   284
skalberg@14516
   285
(* True iff the term is f $ Bound 0 with Bound 0 not occurring (loose) in f,
   i.e. the body of an eta-redex %x. f x. *)
fun eta_redex (f $ Bound 0) =
    let
        fun occurs depth (Bound i) = (i = depth)
          | occurs depth (g $ x) = occurs depth g orelse occurs depth x
          | occurs depth (Abs (_, _, body)) = occurs (depth + 1) body
          | occurs _ _ = false
    in
        not (occurs 0 f)
    end
  | eta_redex _ = false
skalberg@14516
   295
haftmann@21078
   296
(* Simproc body for !!x. P == Q where both P = P' $ Bound 0 and
   Q = Q' $ Bound 0 are eta-redexes (see eta_redex).  It derives
   (!!x. P' x == Q' x) == (P' == (%x. Q' x)) on a frozen copy of the
   term and maps the result back through `final`.  Returns NONE for any
   other shape.  Exceptions raised during the proof are passed to
   OldGoals.print_exn.  The previously computed but unused left-hand
   side of the result has been dropped. *)
fun eta_contract thy _ origt =
    let
        (* freeze_thaw_term replaces TVars by fresh TFrees; freeze_thaw
           presumably freezes the remaining schematic variables (TODO
           confirm).  `final` is meant to undo both. *)
        val (typet,Tinst) = freeze_thaw_term origt
        val (init,thaw) = freeze_thaw (reflexive (cterm_of thy typet))
        val final = inst_tfrees thy Tinst o thaw
        val t = #1 (Logic.dest_equals (prop_of init))
        (* Sanity check: undoing the freezing must give back origt. *)
        val _ =
            let
                val lhs = #1 (Logic.dest_equals (prop_of (final init)))
            in
                if not (lhs aconv origt)
                then (writeln "Something is utterly wrong: (orig,lhs,frozen type,t,tinst)";
                      writeln (string_of_cterm (cterm_of thy origt));
                      writeln (string_of_cterm (cterm_of thy lhs));
                      writeln (string_of_cterm (cterm_of thy typet));
                      writeln (string_of_cterm (cterm_of thy t));
                      app (fn (n,T) => writeln (n ^ ": " ^ (string_of_ctyp (ctyp_of thy T)))) Tinst;
                      writeln "done")
                else ()
            end
    in
        case t of
            Const("all",_) $ (Abs(x,xT,Const("==",_) $ P $ Q)) =>
            ((if eta_redex P andalso eta_redex Q
              then
                  let
                      val cert = cterm_of thy
                      val v = Free(Name.variant (add_term_free_names(t,[])) "v",xT)
                      val cv = cert v
                      val ct = cert t
                      (* (%x. P' x) == (%x. Q' x), under the hypothesis t *)
                      val th = assume ct
                               |> forall_elim cv
                               |> abstract_rule x cv
                      (* (%x. P' x) == P' *)
                      val ext_th = eta_conversion (cert (Abs(x,xT,P)))
                      (* P' == (%x. Q' x) *)
                      val th' = transitive (symmetric ext_th) th
                      val cu = cert (prop_of th')
                      (* P' v == (%x. Q' x) v *)
                      val uth = combination (assume cu) (reflexive cv)
                      (* beta-reduce the right-hand side, generalize over v *)
                      val uth' = beta_conversion false (cert (Abs(x,xT,Q) $ v))
                                 |> transitive uth
                                 |> forall_intr cv
                                 |> implies_intr cu
                      val rew_th = equal_intr (th' |> implies_intr ct) uth'
                  in
                      SOME (final rew_th)
                  end
              else NONE)
             handle e => OldGoals.print_exn e)
          | _ => NONE
    end
skalberg@14516
   347
haftmann@21078
   348
(* Simproc body: beta-converts t with beta_conversion true. *)
fun beta_fun thy _ t =
    cterm_of thy t |> beta_conversion true |> SOME
skalberg@14516
   350
obua@17188
   351
(* Rewrite returned by equals_fun. *)
val meta_sym_rew = thm "refl"

(* Simproc body meant to order meta equalities u == v by Term.term_ord.
   NOTE(review): it matches Const ("op ==", _), whereas meta equality is
   written Const ("==", _) elsewhere in this file (e.g. mk_meta_eq).  This
   branch therefore presumably never fires.  Also, refl does not look like
   a rewrite for u == v, and norm_term does not install equals_simproc.
   Confirm before relying on it. *)
fun equals_fun _ _ t =
    (case t of
         Const ("op ==", _) $ u $ v =>
             if Term.term_ord (u, v) <> LESS then NONE else SOME meta_sym_rew
       | _ => NONE)
obua@17188
   357
haftmann@21078
   358
(* Simproc body for P == Q where P or Q is a lambda abstraction.  When
   the type of P and Q is a function type, it proves
   (P == Q) == (!!v. P v == Q v) for a fresh v.  The proof runs on a
   frozen copy of the term, and the result is mapped back through
   `final`.  Returns NONE when neither side is an abstraction or the type
   is not a function type.  A non-equality argument raises an error,
   which (like any other exception) is reported as
   "eta_expand internal error" and passed to OldGoals.print_exn.
   The error message now separates its text from the printed term, and
   the unreachable value after `error` has been removed. *)
fun eta_expand thy _ origt =
    let
        (* freeze_thaw_term replaces TVars by fresh TFrees; freeze_thaw
           presumably freezes the remaining schematic variables (TODO
           confirm).  `final` is meant to undo both. *)
        val (typet,Tinst) = freeze_thaw_term origt
        val (init,thaw) = freeze_thaw (reflexive (cterm_of thy typet))
        val final = inst_tfrees thy Tinst o thaw
        val t = #1 (Logic.dest_equals (prop_of init))
        (* Sanity check: undoing the freezing must give back origt. *)
        val _ =
            let
                val lhs = #1 (Logic.dest_equals (prop_of (final init)))
            in
                if not (lhs aconv origt)
                then (writeln "Something is utterly wrong: (orig,lhs,frozen type,t,tinst)";
                      writeln (string_of_cterm (cterm_of thy origt));
                      writeln (string_of_cterm (cterm_of thy lhs));
                      writeln (string_of_cterm (cterm_of thy typet));
                      writeln (string_of_cterm (cterm_of thy t));
                      app (fn (n,T) => writeln (n ^ ": " ^ (string_of_ctyp (ctyp_of thy T)))) Tinst;
                      writeln "done")
                else ()
            end
    in
        case t of
            Const("==",T) $ P $ Q =>
            if is_Abs P orelse is_Abs Q
            then (case domain_type T of
                      Type("fun",[aT,_]) =>
                      let
                          val cert = cterm_of thy
                          val vname = Name.variant (add_term_free_names(t,[])) "v"
                          val v = Free(vname,aT)
                          val cv = cert v
                          val ct = cert t
                          (* (P == Q) ==> (!!v. P v == Q v) *)
                          val th1 = combination (assume ct) (reflexive cv)
                                    |> forall_intr cv
                                    |> implies_intr ct
                          val concl = cert (concl_of th1)
                          (* (%v. P v) == (%v. Q v), under the hypothesis concl *)
                          val th2 = assume concl
                                    |> forall_elim cv
                                    |> abstract_rule vname cv
                          val (lhs,rhs) = Logic.dest_equals (prop_of th2)
                          val elhs = eta_conversion (cert lhs)
                          val erhs = eta_conversion (cert rhs)
                          (* eta-contract both sides back to P == Q *)
                          val th2' = transitive (transitive (symmetric elhs) th2) erhs
                          val res = equal_intr th1 (th2' |> implies_intr concl)
                      in
                          SOME (final res)
                      end
                    | _ => NONE)
            else NONE
          | _ => error ("Bad eta_expand argument: " ^ string_of_cterm (cterm_of thy t))
    end
    handle e => (writeln "eta_expand internal error"; OldGoals.print_exn e)
skalberg@14516
   413
wenzelm@14854
   414
(* Constructors and pattern variables used in the simproc patterns. *)
fun mk_tfree s = TFree ("'" ^ s, [])
fun mk_free name T = Free (name, T)
val xT = mk_tfree "a"
val yT = mk_tfree "b"
val P  = mk_free "P" (xT --> yT --> propT)
val Q  = mk_free "Q" (xT --> yT)
val R  = mk_free "R" (xT --> yT)
val S  = mk_free "S" xT
val S' = mk_free "S'" xT
skalberg@14516
   423
in
haftmann@21078
   424
(* Simprocs wrapping the functions above.  Each is keyed by the listed
   term pattern. *)
fun beta_simproc thy =
    Simplifier.simproc_i thy "Beta-contraction"
      [Abs ("x", xT, Q) $ S] beta_fun

(* NOTE(review): the pattern uses "op ==" with type xT, unlike the
   "==" constant used elsewhere in this file — confirm it is intended. *)
fun equals_simproc thy =
    Simplifier.simproc_i thy "Ordered rewriting of meta equalities"
      [Const ("op ==", xT) $ S $ S'] equals_fun

fun quant_simproc thy =
    Simplifier.simproc_i thy "Ordered rewriting of nested quantifiers"
      [all xT $ (Abs ("x", xT, all yT $ (Abs ("y", yT, P $ Bound 1 $ Bound 0))))]
      quant_rewrite

fun eta_expand_simproc thy =
    Simplifier.simproc_i thy "Smart eta-expansion by equivalences"
      [Logic.mk_equals (Q, R)] eta_expand

fun eta_contract_simproc thy =
    Simplifier.simproc_i thy "Smart handling of eta-contractions"
      [all xT $ (Abs ("x", xT, Logic.mk_equals (Q $ Bound 0, R $ Bound 0)))]
      eta_contract
skalberg@14516
   451
end
skalberg@14516
   452
skalberg@14516
   453
(* Disambiguates the names of bound variables in a term.  Every binder is
   renamed to x0, x1, ... in the order met (function before argument,
   outer before inner), so all bound names in the result t' are distinct.
   Returns the theorem t == t', obtained from reflexivity on both sides,
   since t and t' differ only in bound names. *)
fun disamb_bound thy t =
    let
        fun F (t $ u,idx) =
            let
                val (t',idx') = F (t,idx)
                val (u',idx'') = F (u,idx')
            in
                (t' $ u',idx'')
            end
          | F (Abs(_,xT,t),idx) =
            let
                (* plain int conversion, consistent with the rest of the
                   file, instead of LargeInt.toString *)
                val x' = "x" ^ string_of_int idx
                val (t',idx') = F (t,idx+1)
            in
                (Abs(x',xT,t'),idx')
            end
          | F arg = arg
        val (t',_) = F (t,0)
        val ct = cterm_of thy t
        val ct' = cterm_of thy t'
        val res = transitive (reflexive ct) (reflexive ct')
        val _ = message ("disamb_term: " ^ (string_of_thm res))
    in
        res
    end
skalberg@14516
   482
skalberg@14516
   483
(* Transforms a term t to some normal form t', returning the theorem
   t == t'.  This was originally a helper for make_equal, but it may also
   be useful on its own, e.g. for indexing terms.  The normal form is
   produced in three steps:
   1. bound variables are renamed apart (disamb_bound);
   2. the term is rewritten with the theory's shuffle rules and the
      quantifier / eta simprocs;
   3. the result is eta-converted.
   On any exception it prints "norm_term internal error", reports the
   exception and re-raises it. *)
fun norm_term thy t =
    let
        val rules = ShuffleData.get thy
        val ss =
            Simplifier.theory_context thy empty_ss
              setmksimps single
              addsimps (map (Thm.transfer thy) rules)
              addsimprocs [quant_simproc thy, eta_expand_simproc thy, eta_contract_simproc thy]
        (* Extends th : a == b to a == c, given conv producing b == c. *)
        fun chain conv th = transitive th (conv (Thm.rhs_of th))
        val disambiguated = disamb_bound thy t
        val simplified = chain (Simplifier.full_rewrite ss) disambiguated
        val th = strip_shyps (chain eta_conversion simplified)
        val _ = message ("norm_term: " ^ string_of_thm th)
    in
        th
    end
    handle e => (writeln "norm_term internal error"; print_sign_exn thy e)
skalberg@14516
   510
skalberg@14516
   511
skalberg@14516
   512
(* Closes a theorem with respect to its free and schematic term variables
   by universally quantifying over them.  Type variables are not touched.
   On failure it prints "close_thm internal error" and hands the exception
   to OldGoals.print_exn. *)
fun close_thm th =
    let
        val thy = Thm.theory_of_thm th
        val prop = prop_of th
        val vars = add_term_frees (prop, add_term_vars (prop, []))
    in
        th |> Drule.forall_intr_list (map (cterm_of thy) vars)
    end
    handle e => (writeln "close_thm internal error"; OldGoals.print_exn e)
skalberg@14516
   524
skalberg@14516
   525
(* Normalizes a theorem's proposition using norm_term and transports the
   theorem along the resulting equation. *)
fun norm_thm thy th = equal_elim (norm_term thy (prop_of th)) th
skalberg@14516
   533
haftmann@21078
   534
(* make_equal thy t u tries to construct the theorem t == u in the
theory thy by bringing both terms to normal form with norm_term.  If it
succeeds, SOME (t == u) is returned; if a THM exception occurs (e.g. the
normal forms differ), NONE is returned.  Other exceptions propagate. *)
fun make_equal thy t u =
    let
        val th = transitive (norm_term thy t) (symmetric (norm_term thy u))
        val _ = message ("make_equal: SOME " ^ string_of_thm th)
    in
        SOME th
    end
    handle THM _ => (message "make_equal: NONE"; NONE)
wenzelm@21588
   548
skalberg@14516
   549
(* match_consts ignore t returns a predicate on (name, th) pairs.  The
   predicate holds when th's proposition contains the same set of
   constants as t, ignoring constant names listed in ignore.  The
   constants of t are collected once, up front. *)
fun match_consts ignore t (* th *) =
    let
        fun add_consts (Const (c, _), acc) =
            if c mem_string ignore then acc else insert (op =) c acc
          | add_consts (f $ x, acc) = add_consts (f, add_consts (x, acc))
          | add_consts (Abs (_, _, body), acc) = add_consts (body, acc)
          | add_consts (_, acc) = acc
        val t_consts = add_consts (t, [])
    in
        fn (_, th) => eq_set (t_consts, add_consts (prop_of th, []))
    end
wenzelm@21588
   567
skalberg@14516
   568
(* collect_ignored rules acc: for every rule in rules (each expected to be a
   meta-equation lhs == rhs), adds to acc the constants that occur on only
   one side of the equation.  Duplicates are not added twice. *)
val collect_ignored =
    let
        (* Constants occurring in exactly one side of the equation thm. *)
        fun one_sided thm =
            let
                val (lhs, rhs) = Logic.dest_equals (prop_of thm)
                val lhs_consts = term_consts lhs
                val rhs_consts = term_consts rhs
            in
                (lhs_consts \\ rhs_consts) @ (rhs_consts \\ lhs_consts)
            end
        (* Merge the accumulated names into the one-sided constants of thm. *)
        fun step thm acc = fold_rev (insert (op =)) acc (one_sided thm)
    in
        fold_rev step
    end
skalberg@14516
   577
skalberg@14516
   578
(* set_prop thy t thms tries to make a theorem with the proposition t from
skalberg@14516
   579
one of the (name, theorem) pairs in thms, by shuffling the propositions around.
skalberg@15531
   580
If it succeeds, SOME (name, theorem) is returned, otherwise NONE.
Staged: applying set_prop thy t closes t over its free and schematic variables
and normalizes the result once with norm_term; the returned function then tries
each candidate in order.  A candidate whose processing raises THM is skipped. *)
skalberg@14516
   581
skalberg@14516
   582
fun set_prop thy t =
skalberg@14516
   583
    let
wenzelm@21588
   584
        val vars = add_term_frees (t,add_term_vars (t,[])) (* free and schematic variables of t *)
wenzelm@21588
   585
        (* closed_t: t with each variable of vars bound by an outer "all"; vars' head is outermost *)
        val closed_t = Library.foldr (fn (v, body) =>
haftmann@21078
   586
      let val vT = type_of v in all vT $ (Abs ("x", vT, abstract_over (v, body))) end) (vars, t)
wenzelm@21588
   587
        val rew_th = norm_term thy closed_t (* equation from closed_t to its normal form *)
wenzelm@22902
   588
        val rhs = Thm.rhs_of rew_th (* the normal form of closed_t *)
skalberg@14516
   589
wenzelm@21588
   590
        val shuffles = ShuffleData.get thy (* NOTE(review): never used in this function *)
wenzelm@21588
   591
        fun process [] = NONE
wenzelm@21588
   592
          | process ((name,th)::thms) =
wenzelm@21588
   593
            let
wenzelm@21588
   594
                (* candidate: transferred to thy, closed, normalized, then Thm.varifyT *)
                val norm_th = Thm.varifyT (norm_thm thy (close_thm (Thm.transfer thy th)))
wenzelm@21588
   595
                val triv_th = trivial rhs
wenzelm@21588
   596
                val _ = message ("Shuffler.set_prop: Gluing together " ^ (string_of_thm norm_th) ^ " and " ^ (string_of_thm triv_th))
wenzelm@21588
   597
                (* resolve norm_th into subgoal 1 of triv_th; only the first result is kept *)
                val mod_th = case Seq.pull (bicompose false (*true*) (false,norm_th,0) 1 triv_th) of
wenzelm@21588
   598
                                 SOME(th,_) => SOME th
wenzelm@21588
   599
                               | NONE => NONE
wenzelm@21588
   600
            in
wenzelm@21588
   601
                case mod_th of
wenzelm@21588
   602
                    SOME mod_th =>
wenzelm@21588
   603
                    let
wenzelm@21588
   604
                        val closed_th = equal_elim (symmetric rew_th) mod_th (* rewrite back from rhs to closed_t *)
wenzelm@21588
   605
                    in
wenzelm@21588
   606
                        message ("Shuffler.set_prop succeeded by " ^ name);
wenzelm@21588
   607
                        (* strip the quantifiers added for closed_t, instantiating them with vars *)
                        SOME (name,forall_elim_list (map (cterm_of thy) vars) closed_th)
wenzelm@21588
   608
                    end
wenzelm@21588
   609
                  | NONE => process thms
wenzelm@21588
   610
            end
wenzelm@21588
   611
            handle e as THM _ => process thms (* a THM from this candidate: try the remaining ones *)
skalberg@14516
   612
    in
wenzelm@21588
   613
        (* The result function; the outer handler below does not cover its application. *)
        fn thms =>
wenzelm@21588
   614
           case process thms of
wenzelm@21588
   615
               (* sanity check: the reconstructed theorem must have exactly proposition t *)
               res as SOME (name,th) => if (prop_of th) aconv t
wenzelm@21588
   616
                                        then res
wenzelm@21588
   617
                                        else error "Internal error in set_prop"
wenzelm@21588
   618
             | NONE => NONE
skalberg@14516
   619
    end
wenzelm@17959
   620
    (* Catches only errors from the staged setup above (vars, closed_t, rew_th).
       OldGoals.print_exn is used at the function's result type, so it cannot
       return normally; presumably it prints e and re-raises it. *)
    handle e => (writeln "set_prop internal error"; OldGoals.print_exn e)
skalberg@14516
   621
skalberg@14516
   622
(* find_potential thy t returns every fact stored in thy, paired with its name
   hint, whose constants are exactly those of t (as decided by match_consts).
   Constants that appear on only one side of a registered shuffle rule are
   disregarded. *)
fun find_potential thy t =
    let
        val ignored = collect_ignored (ShuffleData.get thy) []
        val same_consts = match_consts ignored t
        val facts = maps #2 (Facts.dest (PureThy.all_facts_of thy))
        val named_facts = map (fn th => (PureThy.get_name_hint th, th)) facts
    in
        List.filter same_consts named_facts
    end
skalberg@14516
   630
skalberg@14516
   631
(* gen_shuffle_tac thy search thms i: tactic that tries to solve subgoal i of
   the proof state with set_prop.  It first uses the candidates thms; then, if
   search is true, it uses the candidates returned by find_potential.  A
   theorem produced by set_prop is composed into subgoal i.
   If subgoal i does not exist, the tactic fails (empty sequence). *)
fun gen_shuffle_tac thy search thms i st =
    let
        val _ = message ("Shuffling " ^ (string_of_thm st))
        val prems = prems_of st
    in
        (* Nonexistent subgoal: fail like an ordinary tactic instead of
           letting List.nth raise Subscript. *)
        if i < 1 orelse i > length prems then Seq.empty
        else
            let
                val t = List.nth (prems, i - 1)
                val set = set_prop thy t
                (* Compose the first candidate set_prop accepts into subgoal i. *)
                fun process_tac cands state =
                    case set cands of
                        SOME (_, th) => Seq.of_list (compose (th, i, state))
                      | NONE => Seq.empty
                (* Eta-expanded on purpose: find_potential scans every fact
                   of thy, so it should only run when APPEND actually reaches
                   this alternative, not while the tactic is being built. *)
                fun fallback_tac state =
                    if search then process_tac (find_potential thy t) state
                    else Seq.empty
            in
                (process_tac thms APPEND fallback_tac) st
            end
    end
skalberg@14516
   645
skalberg@14516
   646
(* shuffle_tac thms i: gen_shuffle_tac with search disabled.  The theory is
   the one returned by the_context (), looked up once the tactic is applied
   to a proof state. *)
fun shuffle_tac thms i st =
    let
        val thy = the_context ()
    in
        gen_shuffle_tac thy false thms i st
    end
skalberg@14516
   648
skalberg@14516
   649
(* search_tac thms i: gen_shuffle_tac with search enabled, so candidates from
   find_potential are tried after thms.  The theory is the one returned by
   the_context (), looked up once the tactic is applied to a proof state. *)
fun search_tac thms i st =
    let
        val thy = the_context ()
    in
        gen_shuffle_tac thy true thms i st
    end
skalberg@14516
   651
skalberg@14516
   652
(* Implementation of the "shuffle_tac" proof method: gen_shuffle_tac without
   search, taking the method's theorem arguments (each with the empty name)
   as candidates. *)
fun shuffle_meth (thms:thm list) ctxt =
    let
        val candidates = map (fn th => ("", th)) thms
    in
        Method.SIMPLE_METHOD'
            (gen_shuffle_tac (ProofContext.theory_of ctxt) false candidates)
    end
skalberg@14516
   658
skalberg@14516
   659
(* Implementation of the "search_tac" proof method: gen_shuffle_tac with
   search enabled, using the context's premises (each named "premise") as the
   first candidates. *)
fun search_meth ctxt =
    let
        val named_prems = map (fn prem => ("premise", prem)) (Assumption.prems_of ctxt)
    in
        Method.SIMPLE_METHOD'
            (gen_shuffle_tac (ProofContext.theory_of ctxt) true named_prems)
    end
skalberg@14516
   666
skalberg@14516
   667
(* add_shuffle_rule thm thy prepends thm to the shuffle rules stored in thy.
   If a rule equal to thm (Thm.eq_thm) is already present, it only issues a
   warning and returns thy unchanged. *)
fun add_shuffle_rule thm thy =
    let
        val known = ShuffleData.get thy
        fun same_as_thm th = Thm.eq_thm (thm, th)
    in
        if not (exists same_as_thm known)
        then ShuffleData.put (thm :: known) thy
        else (warning ((string_of_thm thm) ^ " already known to the shuffler");
              thy)
    end
skalberg@14516
   676
wenzelm@20897
   677
(* Declaration attribute (registered as "shuffle_rule" in setup) that adds the
   theorem to the shuffle rules via add_shuffle_rule.  Context.mapping applies
   add_shuffle_rule th in a theory context and the identity I in a proof
   context, so using the attribute inside a proof has no effect. *)
val shuffle_attr = Thm.declaration_attribute (fn th => Context.mapping (add_shuffle_rule th) I);
skalberg@14516
   678
wenzelm@18708
   679
(* Theory setup for the shuffler:
   - registers the proof methods "shuffle_tac" (shuffle_meth, takes theorem
     arguments) and "search_tac" (search_meth, no arguments);
   - seeds the shuffle rules with weaken, equiv_comm, imp_comm,
     Drule.norm_hhf_eq and Drule.triv_forall_equality (collect_ignored applies
     Logic.dest_equals to every rule, so each is presumably a meta-equation);
   - registers the "shuffle_rule" attribute (shuffle_attr). *)
val setup =
wenzelm@22846
   680
  Method.add_method ("shuffle_tac",
wenzelm@22846
   681
    Method.thms_ctxt_args shuffle_meth,"solve goal by shuffling terms around") #>
wenzelm@22846
   682
  Method.add_method ("search_tac",
wenzelm@22846
   683
    Method.ctxt_args search_meth,"search for suitable theorems") #>
wenzelm@18708
   684
  add_shuffle_rule weaken #>
wenzelm@18708
   685
  add_shuffle_rule equiv_comm #>
wenzelm@18708
   686
  add_shuffle_rule imp_comm #>
wenzelm@18708
   687
  add_shuffle_rule Drule.norm_hhf_eq #>
wenzelm@18708
   688
  add_shuffle_rule Drule.triv_forall_equality #>
wenzelm@18728
   689
  Attrib.add_attributes [("shuffle_rule", Attrib.no_args shuffle_attr, "declare rule for shuffler")]
wenzelm@18708
   690
skalberg@14516
   691
end