src/HOL/Import/shuffler.ML
author nipkow
Mon Jan 30 21:49:41 2012 +0100 (2012-01-30)
changeset 46372 6fa9cdb8b850
parent 46218 ecf6375e2abb
child 46803 f8875c15cbe1
permissions -rw-r--r--
added "'a rel"
skalberg@14620
     1
(*  Title:      HOL/Import/shuffler.ML
skalberg@14516
     2
    Author:     Sebastian Skalberg, TU Muenchen
skalberg@14516
     3
skalberg@14516
     4
Package for proving two terms equal by normalizing (hence the
skalberg@14516
     5
"shuffler" name).  Uses the simplifier for the normalization.
skalberg@14516
     6
*)
skalberg@14516
     7
skalberg@14516
     8
signature Shuffler =
sig
    (* Tracing switch: when set, internal trace messages are printed. *)
    val debug : bool Unsynchronized.ref

    (* Normalization: norm_term returns t == t' with t' a normal form;
       make_equal tries to prove t == u by normalizing both sides;
       set_prop tries to derive a theorem with a given proposition from
       one of the named candidate theorems. *)
    val norm_term : theory -> term -> thm
    val make_equal : theory -> term -> term -> thm option
    val set_prop : theory -> term -> (string * thm) list -> (string * thm) option

    (* Candidate facts of the theory whose constants match the term. *)
    val find_potential : theory -> term -> (string * thm) list

    (* Tactics built on set_prop. *)
    val gen_shuffle_tac : Proof.context -> bool -> (string * thm) list -> int -> tactic
    val shuffle_tac : Proof.context -> thm list -> int -> tactic
    val search_tac : Proof.context -> int -> tactic

    (* The set of shuffle (normalization) rules kept as theory data. *)
    val print_shuffles : theory -> unit
    val add_shuffle_rule : thm -> theory -> theory
    val shuffle_attr : attribute

    val setup : theory -> theory
end
skalberg@14516
    29
skalberg@14516
    30
structure Shuffler :> Shuffler =
skalberg@14516
    31
struct
skalberg@14516
    32
wenzelm@32740
    33
(* Tracing switch: when set, [message] prints its argument. *)
val debug = Unsynchronized.ref false

(* Apply [f] to [x] only while tracing is enabled. *)
fun if_debug f x =
    if not (!debug) then () else f x

(* Emit a trace line while tracing is enabled.  The caller still builds the
   string argument even when tracing is off. *)
val message = if_debug writeln
skalberg@14516
    37
wenzelm@37146
    38
(* Render a theorem as a string in the default print mode (empty mode list),
   without a proof context. *)
fun string_of_thm th =
    Print_Mode.setmp [] Display.string_of_thm_without_context th;
skalberg@14516
    39
wenzelm@33522
    40
(* Theory data holding the shuffle rules: the rewrite rules norm_term hands
   to the simplifier.  Theories are merged with Thm.merge_thms. *)
structure ShuffleData = Theory_Data
(
  type T = thm list
  val empty : T = []
  fun extend rules = rules
  fun merge pair = Thm.merge_thms pair
)
skalberg@14516
    47
wenzelm@22846
    48
(* Print every shuffle rule currently registered in theory thy. *)
fun print_shuffles thy =
    let
        val pretty_rules = map (Display.pretty_thm_global thy) (ShuffleData.get thy)
    in
        Pretty.writeln (Pretty.big_list "Shuffle theorems:" pretty_rules)
    end
skalberg@14516
    51
skalberg@14516
    52
(* Kernel-proved equation (P ==> P ==> Q) == (P ==> Q): a repeated premise
   may be dropped. *)
val weaken =
    let
        val ct = cterm_of Pure.thy
        val p = Free ("P", propT)
        val q = Free ("Q", propT)
        val p_imp_q = Logic.mk_implies (p, q)
        val cp = ct p
        val cpq = ct p_imp_q
        val cppq = ct (Logic.mk_implies (p, p_imp_q))
        (* (P ==> Q) ==> P ==> P ==> Q: the extra premise P is discharged vacuously *)
        val backward = Thm.assume cpq |> implies_intr_list [cpq, cp]
        (* (P ==> P ==> Q) ==> P ==> Q: feed the single P in twice *)
        val hp = Thm.assume cp
        val forward =
            implies_elim_list (Thm.assume cppq) [hp, hp]
            |> implies_intr_list [cppq, cp]
    in
        Drule.export_without_context (Thm.equal_intr forward backward)
    end
skalberg@14516
    70
skalberg@14516
    71
(* Kernel-proved equation (P ==> Q ==> R) == (Q ==> P ==> R): the first two
   premises commute. *)
val imp_comm =
    let
        val ct = cterm_of Pure.thy
        val p = Free ("P", propT)
        val q = Free ("Q", propT)
        val r = Free ("R", propT)
        val cp = ct p
        val cq = ct q
        (* [swap (ca, cb) cimp], with cimp = A ==> B ==> C, proves
           (A ==> B ==> C) ==> B ==> A ==> C *)
        fun swap (ca, cb) cimp =
            implies_elim_list (Thm.assume cimp) [Thm.assume ca, Thm.assume cb]
            |> implies_intr_list [cimp, cb, ca]
        val c_pqr = ct (Logic.mk_implies (p, Logic.mk_implies (q, r)))
        val c_qpr = ct (Logic.mk_implies (q, Logic.mk_implies (p, r)))
        val there = swap (cp, cq) c_pqr
        val back = swap (cq, cp) c_qpr
    in
        Drule.export_without_context (Thm.equal_intr there back)
    end
skalberg@14516
    90
skalberg@14516
    91
(* Kernel-proved equation (!!x y. P x y) == (!!y x. P x y): nested
   meta-quantifiers commute. *)
val all_comm =
    let
        val ct = cterm_of Pure.thy
        val a = TFree ("'a", [])
        val b = TFree ("'b", [])
        val x = Free ("x", a)
        val y = Free ("y", b)
        val body = Free ("P", a --> b --> propT) $ x $ y
        val xy = ct (Logic.all x (Logic.all y body))
        val yx = ct (Logic.all y (Logic.all x body))
        val cx = ct x
        val cy = ct y
        (* Assume ch, strip its two quantifiers, generalize them again in the
           other order, and discharge ch. *)
        fun flip ch elims intrs =
            Thm.assume ch
            |> forall_elim_list elims
            |> forall_intr_list intrs
            |> Thm.implies_intr ch
        val from_yx = flip yx [cy, cx] [cx, cy]
        val from_xy = flip xy [cx, cy] [cy, cx]
    in
        Drule.export_without_context (Thm.equal_intr from_yx from_xy)
    end
skalberg@14516
   116
skalberg@14516
   117
(* Kernel-proved equation (t == u) == (u == t): meta-equality is symmetric. *)
val equiv_comm =
    let
        val ct = cterm_of Pure.thy
        val ty = TFree ("'a", [])
        val left = Free ("t", ty)
        val right = Free ("u", ty)
        (* For c = (a == b): proves (a == b) ==> (b == a) *)
        fun flip_eq c = Thm.implies_intr c (Thm.symmetric (Thm.assume c))
        val c_tu = ct (Logic.mk_equals (left, right))
        val c_ut = ct (Logic.mk_equals (right, left))
    in
        Thm.equal_intr (flip_eq c_tu) (flip_eq c_ut)
        |> Drule.export_without_context
    end
skalberg@14516
   130
skalberg@14516
   131
(* This simplification procedure rewrites !!x y. P x y
skalberg@14516
   132
deterministically, in order for the normalization function, defined
skalberg@14516
   133
below, to handle nested quantifiers robustly *)
skalberg@14516
   134
skalberg@14516
   135
local
skalberg@14516
   136
skalberg@14516
   137
(* Raised by find_bound to report which of two adjacent de Bruijn indices is
   met first. *)
exception RESULT of int

(* Walk t left to right (function part before argument).  Raise RESULT 0 on
   meeting Bound n and RESULT 1 on meeting Bound (n+1); n is shifted by one
   under each abstraction.  Return () if neither occurs. *)
fun find_bound n t =
    (case t of
        Bound i =>
          if i = n then raise RESULT 0
          else if i = n + 1 then raise RESULT 1
          else ()
      | f $ a => (find_bound n f; find_bound n a)
      | Abs (_, _, b) => find_bound (n + 1) b
      | _ => ())
skalberg@14516
   145
skalberg@14516
   146
(* Exchange the de Bruijn indices n and n+1 throughout t, shifting n by one
   under each abstraction.  All other indices are left unchanged. *)
fun swap_bound n t =
    let
        fun swap_index i =
            if i = n then n + 1
            else if i = n + 1 then n
            else i
    in
        case t of
            Bound i => Bound (swap_index i)
          | f $ a => swap_bound n f $ swap_bound n a
          | Abs (name, ty, b) => Abs (name, ty, swap_bound (n + 1) b)
          | _ => t
    end
skalberg@14516
   152
haftmann@21078
   153
(* Prove (!!x y. t) == (!!y x. t'), where xv and yv are the (name, type)
   pairs of the two binders and t' is t with Bound 0 and Bound 1 exchanged.
   The proof resolves all_comm against the trivial goal. *)
fun rew_th thy xv yv t =
    let
        val swapped = swap_bound 0 t
        val goal =
            Logic.mk_equals (Logic.list_all ([xv, yv], t),
                             Logic.list_all ([yv, xv], swapped))
    in
        all_comm RS Thm.trivial (cterm_of thy goal)
    end
skalberg@14516
   162
wenzelm@46201
   163
(* Simproc body for nested meta-quantifiers !!x y. body.
   - If the inner variable y is met first in body (left-to-right traversal,
     see find_bound), swap the quantifiers to !!y x. via rew_th.
   - If the outer variable x is met first, only exchange the binder names,
     and only when y is smaller than x under string_ord.
   - If neither variable occurs (both quantifiers vacuous), there is nothing
     to order and NONE is returned.  This used to abort with an internal
     error. *)
fun quant_rewrite thy _ (t as Const ("all", T1) $ (Abs (x, xT, Const ("all", T2) $ Abs (y, yT, body)))) =
    let
        (* 0: y occurs first; 1: x occurs first; 2: neither occurs *)
        val first = (find_bound 0 body; 2) handle RESULT i => i
    in
        case first of
            0 => SOME (rew_th thy (x, xT) (y, yT) body)
          | 1 =>
              if string_ord (y, x) = LESS then
                  let
                      (* alpha-variant of t: binder names swapped, types kept *)
                      val newt = Const ("all", T1) $ (Abs (y, xT, Const ("all", T2) $ Abs (x, yT, body)))
                      val t_th = Thm.reflexive (cterm_of thy t)
                      val newt_th = Thm.reflexive (cterm_of thy newt)
                  in
                      SOME (Thm.transitive t_th newt_th)
                  end
              else NONE
          | _ => NONE
    end
  | quant_rewrite _ _ _ = (warning "quant_rewrite: Unknown lhs"; NONE)
skalberg@14516
   182
skalberg@14516
   183
(* Turn every schematic type variable (TVar) of t into a TFree whose name is
   not already used by a TFree of t.  Return the frozen term together with
   the list (new TFree name, original TVar) that undoes the freezing. *)
fun freeze_thaw_term t =
    let
        val used0 = Misc_Legacy.add_term_tfree_names (t, [])
        fun freeze (w as (v, _), S) (acc, used) =
            let
                val v' = singleton (Name.variant_list used) v
            in
                ((w, TFree (v', S)) :: acc, v' :: used)
            end
        val (type_inst, _) = fold freeze (Misc_Legacy.term_tvars t) ([], used0)
        fun thaw (w, TFree (v, S)) = (v, TVar (w, S))
          | thaw _ = error "Internal error in Shuffler.freeze_thaw"
    in
        (subst_TVars type_inst t, map thaw type_inst)
    end
skalberg@14516
   200
haftmann@21078
   201
(* For each (name, U), in list order, replace the type variable called name in
   thm by the type U.  This is done by schematizing just that variable and
   instantiating it.  Names that do not occur in thm are skipped. *)
fun inst_tfrees _ [] thm = thm
  | inst_tfrees thy ((name, U) :: rest) thm =
    let
        val cU = ctyp_of thy U
        (* every TFree of thm except the one called name stays fixed *)
        val others = remove (op = o apsnd fst) name (Misc_Legacy.add_term_tfrees (prop_of thm, []))
        val (rens, thm') = Thm.varifyT_global' others thm
        val instantiated =
            (case rens of
                [] => thm'
              | [((_, S), idx)] =>
                  Drule.instantiate_normalize ([(ctyp_of thy (TVar (idx, S)), cU)], []) thm'
              | _ => error "Shuffler.inst_tfrees internal error")
    in
        inst_tfrees thy rest instantiated
    end
skalberg@14516
   217
skalberg@14516
   218
(* True exactly when the term is a lambda abstraction. *)
fun is_Abs t = (case t of Abs _ => true | _ => false)
skalberg@14516
   220
skalberg@14516
   221
(* True when the term has the form f $ Bound 0 and Bound 0 does not occur
   loose in f, so that abstracting it over Bound 0 would give an eta-redex. *)
fun eta_redex t =
    (case t of
        f $ Bound 0 =>
          let
              fun occurs k (Bound i) = i = k
                | occurs k (a $ b) = occurs k a orelse occurs k b
                | occurs k (Abs (_, _, b)) = occurs (k + 1) b
                | occurs _ _ = false
          in
              not (occurs 0 f)
          end
      | _ => false)
skalberg@14516
   231
wenzelm@46201
   232
(* Simproc body for !!x. f x == g x, where both sides are eta-redexes in x
   (see eta_redex).  Proves (!!x. f x == g x) == (f == (%x. g x)).
   Schematic type variables are frozen first (freeze_thaw_term,
   Drule.legacy_freeze_thaw) and thawed again in the result.  A diagnostic is
   printed if thawing does not give back the original term. *)
fun eta_contract thy _ origt =
    let
        val (typet, Tinst) = freeze_thaw_term origt
        val (init, thaw) = Drule.legacy_freeze_thaw (Thm.reflexive (cterm_of thy typet))
        val final = inst_tfrees thy Tinst o thaw
        val t = #1 (Logic.dest_equals (prop_of init))
        (* sanity check: freeze followed by thaw must be the identity on origt *)
        val thawed_lhs = #1 (Logic.dest_equals (prop_of (final init)))
        val _ =
            if thawed_lhs aconv origt then ()
            else
              writeln (cat_lines
                (["Something is utterly wrong: (orig, lhs, frozen type, t, tinst)",
                  Syntax.string_of_term_global thy origt,
                  Syntax.string_of_term_global thy thawed_lhs,
                  Syntax.string_of_term_global thy typet,
                  Syntax.string_of_term_global thy t] @
                  map (fn (n, T) => n ^ ": " ^ Syntax.string_of_typ_global thy T) Tinst))
        fun contract x xT P Q =
            let
                val cert = cterm_of thy
                val v = Free (singleton (Name.variant_list (Term.add_free_names t [])) "v", xT)
                val cv = cert v
                val ct = cert t
                (* (%x. P) == (%x. Q): instantiate the quantifier with v, abstract v again *)
                val abs_eq = Thm.abstract_rule x cv (Thm.forall_elim cv (Thm.assume ct))
                (* f == (%x. Q), eta-contracting the left-hand side *)
                val contracted =
                    Thm.transitive (Thm.symmetric (Thm.eta_conversion (cert (Abs (x, xT, P))))) abs_eq
                val cu = cert (prop_of contracted)
                (* (f == (%x. Q)) ==> !!v. f v == Q[v/x] *)
                val back =
                    Thm.transitive (Thm.combination (Thm.assume cu) (Thm.reflexive cv))
                                   (Thm.beta_conversion false (cert (Abs (x, xT, Q) $ v)))
                    |> Thm.forall_intr cv
                    |> Thm.implies_intr cu
            in
                final (Thm.equal_intr (Thm.implies_intr ct contracted) back)
            end
    in
        case t of
            Const ("all", _) $ (Abs (x, xT, Const ("==", _) $ P $ Q)) =>
              if eta_redex P andalso eta_redex Q then SOME (contract x xT P Q) else NONE
          | _ => NONE
    end
skalberg@14516
   282
wenzelm@46201
   283
(* Simproc body for an equation P == Q between functions where at least one
   side is an abstraction.  Proves (P == Q) == (!!v. P v == Q v) for a fresh
   Free v.  Schematic type variables are frozen first (freeze_thaw_term,
   Drule.legacy_freeze_thaw) and thawed again in the result.  A diagnostic is
   printed if thawing does not give back the original term.  Any term that is
   not a meta-equality is rejected with an error. *)
fun eta_expand thy _ origt =
    let
        val (typet, Tinst) = freeze_thaw_term origt
        val (init, thaw) = Drule.legacy_freeze_thaw (Thm.reflexive (cterm_of thy typet))
        val final = inst_tfrees thy Tinst o thaw
        val t = #1 (Logic.dest_equals (prop_of init))
        (* sanity check: freeze followed by thaw must be the identity on origt *)
        val thawed_lhs = #1 (Logic.dest_equals (prop_of (final init)))
        val _ =
            if thawed_lhs aconv origt then ()
            else
              writeln (cat_lines
                (["Something is utterly wrong: (orig, lhs, frozen type, t, tinst)",
                  Syntax.string_of_term_global thy origt,
                  Syntax.string_of_term_global thy thawed_lhs,
                  Syntax.string_of_term_global thy typet,
                  Syntax.string_of_term_global thy t] @
                  map (fn (n, T) => n ^ ": " ^ Syntax.string_of_typ_global thy T) Tinst))
        (* aT is the argument type of the functions being equated *)
        fun expand aT =
            let
                val cert = cterm_of thy
                val vname = singleton (Name.variant_list (Term.add_free_names t [])) "v"
                val v = Free (vname, aT)
                val cv = cert v
                val ct = cert t
                (* (P == Q) ==> !!v. P v == Q v *)
                val th1 =
                    Thm.combination (Thm.assume ct) (Thm.reflexive cv)
                    |> Thm.forall_intr cv
                    |> Thm.implies_intr ct
                val concl = cert (concl_of th1)
                (* (%v. P v) == (%v. Q v) *)
                val th2 =
                    Thm.assume concl
                    |> Thm.forall_elim cv
                    |> Thm.abstract_rule vname cv
                val (lhs, rhs) = Logic.dest_equals (prop_of th2)
                val elhs = Thm.eta_conversion (cert lhs)
                val erhs = Thm.eta_conversion (cert rhs)
                (* eta-contract both sides to recover P == Q *)
                val th2' = Thm.transitive (Thm.transitive (Thm.symmetric elhs) th2) erhs
            in
                final (Thm.equal_intr th1 (Thm.implies_intr concl th2'))
            end
    in
        case t of
            Const ("==", T) $ P $ Q =>
              if is_Abs P orelse is_Abs Q then
                  (case domain_type T of
                      Type ("fun", [aT, _]) => SOME (expand aT)
                    | _ => NONE)
              else NONE
          | _ => error ("Bad eta_expand argument: " ^ Syntax.string_of_term_global thy t)
    end;
skalberg@14516
   338
wenzelm@14854
   339
(* Ingredients for the simproc patterns below.  The unused S and S' have
   been removed. *)
fun mk_tfree s = TFree ("'" ^ s, [])
fun mk_free s t = Free (s, t)
val xT = mk_tfree "a"
val yT = mk_tfree "b"
val x = Free ("x", xT)
val y = Free ("y", yT)
val P = mk_free "P" (xT --> yT --> propT)
val Q = mk_free "Q" (xT --> yT)
val R = mk_free "R" (xT --> yT)
skalberg@14516
   350
in
obua@17188
   351
wenzelm@38715
   352
(* Simproc on nested quantifiers !!x y. P x y, implemented by quant_rewrite. *)
fun quant_simproc thy =
    Simplifier.simproc_global_i thy "Ordered rewriting of nested quantifiers"
      [Logic.all x (Logic.all y (P $ x $ y))] quant_rewrite

(* Simproc on function equalities Q == R, implemented by eta_expand. *)
fun eta_expand_simproc thy =
    Simplifier.simproc_global_i thy "Smart eta-expansion by equivalences"
      [Logic.mk_equals (Q, R)] eta_expand

(* Simproc on !!x. Q x == R x, implemented by eta_contract. *)
fun eta_contract_simproc thy =
    Simplifier.simproc_global_i thy "Smart handling of eta-contractions"
      [Logic.all x (Logic.mk_equals (Q $ x, R $ x))] eta_contract
skalberg@14516
   367
end
skalberg@14516
   368
skalberg@14516
   369
(* Disambiguates the names of bound variables in a term, returning t
skalberg@14516
   370
== t' where all the names of bound variables in t' are unique *)
skalberg@14516
   371
haftmann@21078
   372
(* Prove t == t', where t' is t with every abstraction renamed to "x<k>".
   k counts abstractions in a left-to-right traversal starting from 0, so
   all bound names in t' are distinct.  The debug trace is only
   pretty-printed when tracing is on; previously the theorem was rendered on
   every call. *)
fun disamb_bound thy t =
    let
        fun rename (f $ a, idx) =
              let
                  val (f', idx') = rename (f, idx)
                  val (a', idx'') = rename (a, idx')
              in
                  (f' $ a', idx'')
              end
          | rename (Abs (_, T, b), idx) =
              let
                  val (b', idx') = rename (b, idx + 1)
              in
                  (Abs ("x" ^ string_of_int idx, T, b'), idx')
              end
          | rename arg = arg
        val (t', _) = rename (t, 0)
        (* t and t' are alpha-equivalent, so reflexivity chains them *)
        val res = Thm.transitive (Thm.reflexive (cterm_of thy t)) (Thm.reflexive (cterm_of thy t'))
        val _ = if !debug then message ("disamb_term: " ^ string_of_thm res) else ()
    in
        res
    end
skalberg@14516
   398
skalberg@14516
   399
(* Transforms a term t to some normal form t', returning the theorem t
skalberg@14516
   400
== t'.  This is originally a help function for make_equal, but might
skalberg@14516
   401
be handy in its own right, for example for indexing terms. *)
skalberg@14516
   402
skalberg@14516
   403
(* Prove t == t' with t' a normal form of t.  The steps are: make bound
   names distinct (disamb_bound); rewrite with the theory's shuffle rules and
   the quantifier/eta simprocs; eta-convert; strip sort hypotheses.  The trace
   theorem is only pretty-printed when tracing is on, because this function
   runs for every candidate fact during search. *)
fun norm_term thy t =
    let
        val ss = Simplifier.global_context thy empty_ss
          addsimps (map (Thm.transfer thy) (ShuffleData.get thy))
          addsimprocs [quant_simproc thy, eta_expand_simproc thy, eta_contract_simproc thy]
        (* extend th : a == b by the conversion f applied to b *)
        fun chain f th = Thm.transitive th (f (Thm.rhs_of th))
        val th =
            t |> disamb_bound thy
              |> chain (Simplifier.full_rewrite ss)
              |> chain Thm.eta_conversion
              |> Thm.strip_shyps
        val _ = if !debug then message ("norm_term: " ^ string_of_thm th) else ()
    in
        th
    end
skalberg@14516
   424
skalberg@14516
   425
skalberg@14516
   426
(* Closes a theorem with respect to free and schematic variables (does
skalberg@14516
   427
not touch type variables, though). *)
skalberg@14516
   428
skalberg@14516
   429
(* Universally close th over the free and schematic term variables of its
   proposition.  Type variables are left untouched. *)
fun close_thm th =
    let
        val prop = prop_of th
        val frees_and_vars =
            Misc_Legacy.add_term_frees (prop, Misc_Legacy.add_term_vars (prop, []))
    in
        Drule.forall_intr_list (map (cterm_of (Thm.theory_of_thm th)) frees_and_vars) th
    end
wenzelm@37778
   437
skalberg@14516
   438
skalberg@14516
   439
(* Normalizes a theorem's conclusion using norm_term. *)
skalberg@14516
   440
skalberg@14516
   441
(* Rewrite th's proposition into the normal form computed by norm_term. *)
fun norm_thm thy th = Thm.equal_elim (norm_term thy (prop_of th)) th
skalberg@14516
   447
haftmann@21078
   448
(* make_equal thy t u tries to construct the theorem t == u under the
haftmann@21078
   449
signature thy.  If it succeeds, SOME (t == u) is returned, otherwise
skalberg@15531
   450
NONE is returned. *)
skalberg@14516
   451
haftmann@21078
   452
(* Try to prove t == u by normalizing both sides with norm_term and chaining
   the results.  Returns NONE if the kernel rejects the chaining (THM), i.e.
   the normal forms differ.  The success trace is only pretty-printed when
   tracing is on.  The unused exception binding has been removed. *)
fun make_equal thy t u =
    let
        val t_is_t' = norm_term thy t
        val u_is_u' = norm_term thy u
        val th = Thm.transitive t_is_t' (Thm.symmetric u_is_u')
        val _ = if !debug then message ("make_equal: SOME " ^ string_of_thm th) else ()
    in
        SOME th
    end
    handle THM _ => (message "make_equal: NONE"; NONE)
wenzelm@21588
   462
skalberg@14516
   463
(* Return a predicate on named theorems.  It holds when the theorem's
   proposition mentions exactly the same set of constant names as t, with
   the names in [ignore] disregarded on both sides. *)
fun match_consts ignore t =
    let
        fun consts_of tm =
            let
                fun add (Const (c, _)) cs =
                      if member (op =) ignore c then cs else insert (op =) c cs
                  | add (f $ a) cs = add f (add a cs)
                  | add (Abs (_, _, b)) cs = add b cs
                  | add _ cs = cs
            in
                add tm []
            end
        val t_consts = consts_of t
    in
        fn (_, th) => eq_set (op =) (t_consts, consts_of (prop_of th))
    end
wenzelm@21588
   481
haftmann@33040
   482
(* Add to acc, without duplicates, every constant name that occurs on only
   one side of some rule lhs == rhs in thms. *)
fun collect_ignored thms acc =
    let
        fun one_sided_consts thm =
            let
                val (lhs, rhs) = Logic.dest_equals (prop_of thm)
                val in_lhs = Term.add_const_names lhs []
                val in_rhs = Term.add_const_names rhs []
            in
                subtract (op =) in_rhs in_lhs @ subtract (op =) in_lhs in_rhs
            end
        fun add_rule thm cs = fold_rev (insert (op =)) cs (one_sided_consts thm)
    in
        fold_rev add_rule thms acc
    end
skalberg@14516
   492
skalberg@14516
   493
(* set_prop thy t thms tries to make a theorem with the proposition t from
skalberg@14516
   494
one of the theorems thms, by shuffling the propositions around.  If it
skalberg@15531
   495
succeeds, SOME theorem is returned, otherwise NONE.  *)
skalberg@14516
   496
skalberg@14516
   497
(* Partial application [set_prop thy t] normalizes the universal closure of
   t once and returns a function on named candidate theorems.  That function
   takes the first candidate whose normalized closure composes with the
   normalized goal, and instantiates the result back to t.  Candidates that
   fail with THM are skipped.  The "Gluing together" trace pretty-prints two
   theorems per candidate, so it is now built only when tracing is on.
   Previously this cost was paid for every candidate during search. *)
fun set_prop thy t =
    let
        val vars = Misc_Legacy.add_term_frees (t, Misc_Legacy.add_term_vars (t, []))
        val closed_t = fold_rev Logic.all vars t
        val rew_th = norm_term thy closed_t
        val rhs = Thm.rhs_of rew_th

        fun process [] = NONE
          | process ((name, th) :: thms) =
            let
                val norm_th = Thm.varifyT_global (norm_thm thy (close_thm (Thm.transfer thy th)))
                val triv_th = Thm.trivial rhs
                val _ =
                    if !debug then
                        message ("Shuffler.set_prop: Gluing together " ^ string_of_thm norm_th ^
                                 " and " ^ string_of_thm triv_th)
                    else ()
                val mod_th =
                    case Seq.pull (Thm.bicompose false (*true*) (false, norm_th, 0) 1 triv_th) of
                        SOME (glued, _) => SOME glued
                      | NONE => NONE
            in
                case mod_th of
                    SOME mod_th =>
                      let
                          val closed_th = Thm.equal_elim (Thm.symmetric rew_th) mod_th
                      in
                          message ("Shuffler.set_prop succeeded by " ^ name);
                          SOME (name, forall_elim_list (map (cterm_of thy) vars) closed_th)
                      end
                  | NONE => process thms
            end
            handle THM _ => process thms
    in
        fn thms =>
           case process thms of
               res as SOME (_, th) =>
                 if prop_of th aconv t then res
                 else error "Internal error in set_prop"
             | NONE => NONE
    end
skalberg@14516
   533
skalberg@14516
   534
(* Named static facts of thy whose constant set matches t's.  Constants that
   the shuffle rules introduce or remove on one side are ignored (see
   collect_ignored). *)
fun find_potential thy t =
    let
        val ignored = collect_ignored (ShuffleData.get thy) []
        val named_facts = Facts.dest_static [] (Global_Theory.facts_of thy)
        val candidates = map (`Thm.get_name_hint) (maps #2 named_facts)
    in
        filter (match_consts ignored t) candidates
    end
skalberg@14516
   543
wenzelm@42368
   544
(* Solve subgoal i by composing with a theorem obtained through set_prop from
   the given named theorems.  If search is true, also try the facts suggested
   by find_potential as an alternative. *)
fun gen_shuffle_tac ctxt search thms = SUBGOAL (fn (goal, i) =>
    let
        val thy = Proof_Context.theory_of ctxt
        val set = set_prop thy goal
        fun process_tac cands st =
            (case set cands of
                NONE => Seq.empty
              | SOME (_, th) => Seq.of_list (compose (th, i, st)))
        val search_part =
            if search then process_tac (find_potential thy goal) else no_tac
    in
        process_tac thms APPEND search_part
    end)
skalberg@14516
   556
wenzelm@31244
   557
(* Shuffle with the given theorems only (no fact search). *)
fun shuffle_tac ctxt thms =
  gen_shuffle_tac ctxt false (map (fn th => ("", th)) thms);

(* Shuffle with the context's premises, then search the theory's facts. *)
fun search_tac ctxt =
  let val prems = Assumption.all_prems_of ctxt
  in gen_shuffle_tac ctxt true (map (fn th => ("premise", th)) prems) end;
skalberg@14516
   562
skalberg@14516
   563
(* Register thm as a shuffle rule of thy.  If an equal rule (Thm.eq_thm) is
   already registered, warn and leave thy unchanged. *)
fun add_shuffle_rule thm thy =
    let
        val known = ShuffleData.get thy
    in
        if not (exists (fn th => Thm.eq_thm (thm, th)) known) then
            ShuffleData.put (thm :: known) thy
        else
            (warning ((string_of_thm thm) ^ " already known to the shuffler"); thy)
    end
skalberg@14516
   572
wenzelm@20897
   573
(* Attribute registering a theorem as a shuffle rule.  On theories it calls
   add_shuffle_rule; proof contexts are left unchanged. *)
val shuffle_attr =
  Thm.declaration_attribute (fn th => Context.mapping (add_shuffle_rule th) (fn ctxt => ctxt));
skalberg@14516
   574
wenzelm@18708
   575
(* Theory setup: the shuffle_tac and search_tac proof methods, the built-in
   shuffle rules (registered in the order listed), and the shuffle_rule
   attribute. *)
val setup =
  Method.setup @{binding shuffle_tac}
    (Attrib.thms >> (fn ths => fn ctxt => SIMPLE_METHOD' (shuffle_tac ctxt ths)))
    "solve goal by shuffling terms around" #>
  Method.setup @{binding search_tac}
    (Scan.succeed (fn ctxt => SIMPLE_METHOD' (search_tac ctxt))) "search for suitable theorems" #>
  fold add_shuffle_rule
    [weaken, equiv_comm, imp_comm, Drule.norm_hhf_eq, Drule.triv_forall_equality] #>
  Attrib.setup @{binding shuffle_rule} (Scan.succeed shuffle_attr) "declare rule for shuffler";
wenzelm@18708
   587
skalberg@14516
   588
end