src/HOL/Tools/set_comprehension_pointfree.ML
author wenzelm
Sat Jul 25 23:41:53 2015 +0200 (2015-07-25)
changeset 60781 2da59cdf531c
parent 60696 8304fb4fb823
child 61125 4c68426800de
permissions -rw-r--r--
updated to infer_instantiate;
tuned;
wenzelm@48124
     1
(*  Title:      HOL/Tools/set_comprehension_pointfree.ML
bulwahn@48049
     2
    Author:     Felix Kuperjans, Lukas Bulwahn, TU Muenchen
wenzelm@48124
     3
    Author:     Rafal Kolanski, NICTA
bulwahn@48049
     4
bulwahn@48049
     5
Simproc for rewriting set comprehensions to pointfree expressions.
bulwahn@48049
     6
*)
bulwahn@48049
     7
bulwahn@48049
     8
(*Public interface: simprocs that rewrite set comprehensions into pointfree
  set expressions (images of Sigma/sup/inf/vimage/complement combinations).*)
signature SET_COMPREHENSION_POINTFREE =
sig
  (*rewrite the redex itself, which is expected to be a set comprehension*)
  val base_simproc : Proof.context -> cterm -> thm option
  (*rewrite the set-comprehension argument of a redex of the form "pred {...}"*)
  val simproc : Proof.context -> cterm -> thm option
  (*base_simproc wrapped between rewrites with eq_equal[symmetric] and eq_equal*)
  val code_simproc : Proof.context -> cterm -> thm option
end
bulwahn@48049
    14
bulwahn@48049
    15
structure Set_Comprehension_Pointfree : SET_COMPREHENSION_POINTFREE =
bulwahn@48049
    16
struct
bulwahn@48049
    17
bulwahn@48049
    18
(* syntactic operations *)
bulwahn@48049
    19
bulwahn@48049
    20
(*Apply constant c to t1 and t2 at type T --> T --> T, where T is the type
  of t1 (both operands are assumed to share that type).*)
fun mk_endo_binop c (t1, t2) =
  let val T = fastype_of t1
  in Const (c, T --> T --> T) $ t1 $ t2 end

(*inf t1 t2; used below for the intersection of two sets*)
val mk_inf = mk_endo_binop @{const_name Lattices.inf_class.inf}

(*sup t1 t2; used below for the union of two sets*)
val mk_sup = mk_endo_binop @{const_name Lattices.sup_class.sup}

(*uminus t; used below for the complement of a set*)
fun mk_Compl t =
  let val T = fastype_of t
  in Const (@{const_name "Groups.uminus_class.uminus"}, T --> T) $ t end
bulwahn@49768
    40
bulwahn@48049
    41
(*Image "t1 ` t2".  The element type of the result is the range type of t1,
  so t1 must have a function type.  Otherwise range_type raises a descriptive
  TYPE exception, rather than the anonymous Bind of a non-exhaustive pattern
  binding.*)
fun mk_image t1 t2 =
  let
    val fT = fastype_of t1
    val resT = HOLogic.mk_setT (range_type fT)
  in
    Const (@{const_name image}, fT --> fastype_of t2 --> resT) $ t1 $ t2
  end;

(*Sigma t1 (%_. t2): t2 is wrapped with absdummy over the element type of
  t1, so for a set t2 this denotes the product of the two sets.*)
fun mk_sigma (t1, t2) =
  let
    val T1 = fastype_of t1
    val T2 = fastype_of t2
    val setT = HOLogic.dest_setT T1
    val resT = HOLogic.mk_setT (HOLogic.mk_prodT (setT, HOLogic.dest_setT T2))
  in
    Const (@{const_name Sigma},
      T1 --> (setT --> T2) --> resT) $ t1 $ absdummy setT t2
  end;

(*Preimage "f -` s" for f of type A => B, giving a set over A.  Like
  mk_image, this uses domain_type/range_type instead of a non-exhaustive
  pattern binding on the function type.*)
fun mk_vimage f s =
  let
    val fT = fastype_of f
    val domT = domain_type fT
    val ranT = range_type fT
  in
    Const (@{const_name vimage}, fT --> HOLogic.mk_setT ranT --> HOLogic.mk_setT domT) $ f $ s
  end;
bulwahn@49874
    66
bulwahn@49857
    67
(*Split "Collect (%x::T. body)" into ((x, T), body); inside body, x remains
  as the loose Bound 0.*)
fun dest_Collect t =
  (case t of
    Const (@{const_name Collect}, _) $ Abs (x, T, body) => ((x, T), body)
  | _ => raise TERM ("dest_Collect", [t]))

(* Adapted from strip_ex in predicate_compile_aux.ML *)
(*Strip a prefix of existential quantifiers "EX x1 ... xn. body".  Returns
  the binders (outermost first) together with the remaining body.*)
fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, body)) =
      apfst (cons (x, T)) (strip_ex body)
  | strip_ex t = ([], t)
bulwahn@48049
    78
bulwahn@49849
    79
(*Pair constructor applied to t1 and t2.  Their types are computed with
  fastype_of1 under the bound-variable types Ts.*)
fun mk_prod1 Ts (t1, t2) =
  HOLogic.pair_const (fastype_of1 (Ts, t1)) (fastype_of1 (Ts, t2)) $ t1 $ t2;

(*Abstract body t over the variables of a (nested) Pair of Bounds, producing
  nested HOLogic.mk_split abstractions.  vs supplies name and type for each
  bound index; any other term raises TERM.*)
fun mk_split_abs vs (Bound i) t = (fn (x, T) => Abs (x, T, t)) (nth vs i)
  | mk_split_abs vs (Const (@{const_name Product_Type.Pair}, _) $ u $ v) t =
      HOLogic.mk_split (mk_split_abs vs u (mk_split_abs vs v t))
  | mk_split_abs _ t _ = raise TERM ("mk_split_abs: bad term", [t]);

(* a variant of HOLogic.strip_psplits *)
(*Strip case_prod and lambda prefixes while following the tuple positions.
  Where a position has neither, the term is eta-expanded with a dummy binder.
  Returns (body, binders outermost first, split positions).*)
fun strip_psplits tm =
  let
    fun go [] splits binders t = (t, rev binders, splits)
      | go (pos :: poss) splits binders (Const (@{const_name Product_Type.prod.case_prod}, _) $ t) =
          go ((1 :: pos) :: (2 :: pos) :: poss) (pos :: splits) binders t
      | go (_ :: poss) splits binders (Abs (name, T, t)) =
          go poss splits ((name, T) :: binders) t
      | go (_ :: poss) splits binders t =
          let
            val argT = hd (binder_types (fastype_of1 (map snd binders, t)))
          in
            go poss splits ((Name.uu_, argT) :: binders) (incr_boundvars 1 t $ Bound 0)
          end
  in go [[]] [] [] tm end;
bulwahn@49849
   102
bulwahn@49849
   103
(* patterns *)
bulwahn@49849
   104
bulwahn@50024
   105
(*A pattern is a list of bound-variable indices, denoting the tuple of those
  Bound variables.*)
datatype pattern = Pattern of int list

fun dest_Pattern (Pattern bs) = bs

(*index of a Bound term; anything else raises TERM*)
fun dest_bound t =
  (case t of
    Bound i => i
  | _ => raise TERM ("dest_bound", [t]));

(*tuple type of a pattern; Ts gives the type of each bound index*)
fun type_of_pattern Ts (Pattern bs) = HOLogic.mk_tupleT (map (nth Ts) bs)

(*right-nested pair term (Bound b1, (Bound b2, ...)) of a non-empty pattern*)
fun term_of_pattern Ts (Pattern bs) =
  let
    fun tuple [b] = Bound b
      | tuple (b :: rest) =
          HOLogic.pair_const (nth Ts b) (type_of_pattern Ts (Pattern rest))
            $ Bound b $ tuple rest
  in tuple bs end;
bulwahn@49849
   120
bulwahn@49849
   121
(* formulas *)
bulwahn@49849
   122
bulwahn@49849
   123
(*Formulas over atoms: an atom pairs the pattern of bound variables it
  constrains with the set those variables must lie in.*)
datatype formula =
    Atom of (pattern * term)
  | Int of formula * formula
  | Un of formula * formula

(*apply f to the payload of an Atom; other formulas are returned unchanged*)
fun map_atom f fm =
  (case fm of
    Atom a => Atom (f a)
  | _ => fm)

(*true for a single atom whose set is "Collect _" or "- Collect _"*)
fun is_collect_atom (Atom (_, s)) =
      (case s of
        Const (@{const_name Collect}, _) $ _ => true
      | Const (@{const_name "Groups.uminus_class.uminus"}, _) $ (Const (@{const_name Collect}, _) $ _) => true
      | _ => false)
  | is_collect_atom _ = false

(*Abstract t over the non-empty binder list vs using nested split_const
  (result type rT).  Returns the argument type together with the term.*)
fun mk_split _ [(x, T)] t = (T, Abs (x, T, t))
  | mk_split rT ((x, T) :: vs) t =
      let
        val (innerT, inner) = mk_split rT vs t
        val split_t = HOLogic.split_const (T, innerT, rT) $ Abs (x, T, inner)
      in (domain_type (fastype_of split_t), split_t) end
bulwahn@50025
   138
bulwahn@50030
   139
(*Restrict t to the bound indices it actually mentions.
  vs lists the binders outermost first, so "nth (rev vs) b" is the binder of
  Bound b.  Returns the pattern of mentioned indices (in loose_bnos order),
  their binders, and t with the k-th of the m mentioned indices renumbered
  to Bound (m - 1 - k).  That numbering matches the nesting order of
  mk_split.  A t without loose bounds raises Empty.*)
fun mk_term vs t =
  let
    val bnos = loose_bnos t
    val vs' = map (nth (rev vs)) bnos
    val renaming =
      rev bnos
      |> map_index (fn (i, j) => (j, Bound i))
      |> sort (int_ord o apply2 fst)
    val max_bno = fst (List.last renaming)
    val subst =
      map (fn i => the_default (Bound i) (AList.lookup (op =) renaming i)) (0 upto max_bno)
  in (Pattern bnos, (vs', subst_bounds (subst, t))) end

(*Fallback atom: the set of tuples (over the variables t mentions) that
  satisfy t, built as a Collect of a split abstraction.*)
fun default_atom vs t =
  let
    val (pat, (vs', t')) = mk_term vs t
    val tupleT = HOLogic.mk_tupleT (map snd vs')
    val (_, pred) = mk_split @{typ bool} vs' t'
  in (pat, Atom (pat, HOLogic.Collect_const tupleT $ pred)) end
bulwahn@50030
   158
bulwahn@50032
   159
(*Translate one conjunct into (pattern, Atom).
  - A membership "x : s" with s free of loose bounds becomes an atom over the
    bounds of x: directly when x is a tuple of Bounds, otherwise as the
    vimage of s under the tuple function.
  - A negation complements the set of the translated atom.
  - Anything else falls back to default_atom.*)
fun mk_atom vs (t as Const (@{const_name "Set.member"}, _) $ x $ s) =
      if null (loose_bnos s) then
        (case try (map dest_bound o HOLogic.strip_tuple) x of
          SOME bs => (Pattern bs, Atom (Pattern bs, s))
        | NONE =>
            let
              val (pat, (vs', x')) = mk_term vs x
              val elemT = HOLogic.dest_setT (fastype_of s)
              val (_, f) = mk_split elemT vs' x'
            in (pat, Atom (pat, mk_vimage f s)) end)
      else default_atom vs t
  | mk_atom vs (Const (@{const_name "HOL.Not"}, _) $ t) =
      apsnd (map_atom (apsnd mk_Compl)) (mk_atom vs t)
  | mk_atom vs t = default_atom vs t
bulwahn@49849
   173
bulwahn@50025
   174
(*Grow the index pattern pat by absorbing every pattern of the second list
  that overlaps it.  Then recurse with the two lists' roles swapped, stopping
  when pat is empty or the list to be scanned is empty.  Returns the grown
  pattern and the leftovers of both lists in their original roles.*)
fun merge' [] (pats1, pats2) = ([], (pats1, pats2))
  | merge' pat (pats, []) = (pat, (pats, []))
  | merge' pat (pats1, pats2) =
      let
        val (disjoint, overlapping) = List.partition (fn p => null (inter (op =) pat p)) pats2
        val grown = distinct (op =) (flat overlapping @ pat)
        val (pat', (rest2, rest1)) = merge' grown (disjoint, pats1)
      in (pat', (rest1, rest2)) end

(*Union two lists of index patterns, joining overlapping patterns (via merge').*)
fun merge ([], pats) = pats
  | merge (pat :: rest1, pats2) =
      let val (grown, (rest1', rest2')) = merge' pat (rest1, pats2)
      in grown :: merge (rest1', rest2') end;

(*Combine patterns for a conjunction.  Each pattern of the first list may
  overlap at most one pattern of the second, and one of the two must contain
  the other; the larger one is kept.  Anything else is rejected with error.*)
fun restricted_merge ([], pats) = pats
  | restricted_merge (pat :: rest, pats) =
      let
        val (disjoint, overlapping) = List.partition (fn p => null (inter (op =) pat p)) pats
        fun keep p = p :: restricted_merge (rest, disjoint)
      in
        (case overlapping of
          [] => keep pat
        | [pat'] =>
            if subset (op =) (pat, pat') then keep pat'
            else if subset (op =) (pat', pat) then keep pat
            else error "restricted merge: two patterns require relational join"
        | _ => error "restricted merge: multiple patterns overlap")
      end;
bulwahn@50025
   207
  
bulwahn@50025
   208
(*apply f to every atom payload of a formula, preserving its Int/Un shape*)
fun map_atoms f fm =
  (case fm of
    Atom a => Atom (f a)
  | Un (fm1, fm2) => Un (map_atoms f fm1, map_atoms f fm2)
  | Int (fm1, fm2) => Int (map_atoms f fm1, map_atoms f fm2))

(*append one UNIV factor (via mk_sigma) per extra bound index in bs;
  Ts gives the type of each bound index*)
fun extend Ts bs t =
  let val univs = map (HOLogic.mk_UNIV o nth Ts) bs
  in foldr1 mk_sigma (t :: univs) end
bulwahn@49849
   213
bulwahn@50025
   214
(*Map the set t, whose tuples are presumably ordered as pat, onto tuples
  ordered as pat'.  Each index of pat' is looked up by position in pat, and
  the resulting reordering function is applied with mk_image.*)
fun rearrange vs (pat, pat') t =
  let
    val subst = map_index (fn (i, b) => (b, i)) (rev pat)
    val vs' = map (nth (rev vs)) pat
    val Ts' = map snd (rev vs')
    val bs = map (fn b => the (AList.lookup (op =) subst b)) pat'
    val rt = term_of_pattern Ts' (Pattern bs)
    val rT = type_of_pattern Ts' (Pattern bs)
    val (_, f) = mk_split rT vs' rt
  in
    mk_image f t
  end;

(*Adjust an atom (Pattern pat, t) to the first merged pattern p in pats that
  shares an index with pat.  t is extended with UNIV factors for the indices
  of p missing from pat, then reordered to p.  If no merged pattern overlaps,
  a descriptive TERM is raised (previously an anonymous Bind from the
  non-exhaustive "val SOME p" binding).*)
fun adjust vs pats (Pattern pat, t) =
  let
    val overlaps = not o null o inter (op =) pat
    val p =
      (case find_first overlaps pats of
        SOME p => p
      | NONE => raise TERM ("adjust: no merged pattern overlaps atom", [t]))
    val missing = subtract (op =) pat p
    val Ts = rev (map snd vs)
    val t' = extend Ts missing t
  in (Pattern p, rearrange vs (pat @ missing, p) t') end

(*adjust every atom of a formula to the merged patterns pats*)
fun adjust_atoms vs pats fm = map_atoms (adjust vs pats) fm
bulwahn@50025
   236
bulwahn@50025
   237
(*Shared driver: combine the two pattern lists with combine, adjust both
  sub-formulas to the merged patterns, and join them with con.*)
fun merge_formulas combine con vs (pats1, fm1) (pats2, fm2) =
  let
    val merged = combine (map dest_Pattern pats1, map dest_Pattern pats2)
    val adjust_fm = adjust_atoms vs merged
  in (map Pattern merged, con (adjust_fm fm1, adjust_fm fm2)) end;

(*conjunction: patterns are combined with restricted_merge*)
fun merge_inter vs = merge_formulas restricted_merge Int vs

(*disjunction: patterns are combined with merge*)
fun merge_union vs = merge_formulas merge Un vs

(*Translate a HOL formula built from & and | over atoms into a formula,
  together with its list of patterns.*)
fun mk_formula vs t =
  (case t of
    @{const HOL.conj} $ t1 $ t2 => merge_inter vs (mk_formula vs t1) (mk_formula vs t2)
  | @{const HOL.disj} $ t1 $ t2 => merge_union vs (mk_formula vs t1) (mk_formula vs t2)
  | _ => apfst single (mk_atom vs t))

(*split a right-nested chain of Int into its components*)
fun strip_Int fm =
  (case fm of
    Int (fm1, fm2) => fm1 :: strip_Int fm2
  | _ => [fm])
bulwahn@49849
   259
bulwahn@49849
   260
(* term construction *)
bulwahn@49849
   261
bulwahn@49849
   262
(*Renumber the loose bounds of t along the concatenation of pats: the k-th
  listed index (of n) becomes Bound (n - 1 - k).  This presumably assumes
  pats together list each bound index exactly once — confirm at callers.*)
fun reorder_bounds pats t =
  let
    val bounds = maps dest_Pattern pats
    val bperm = bounds ~~ ((length bounds - 1) downto 0)
      |> sort (int_ord o apply2 fst) |> map snd
  in
    subst_bounds (map Bound bperm, t)
  end;

(*True iff t, after stripping case_prod abstractions, is a tuple of bound
  variables only (i.e. t merely rearranges its arguments).  Non-Bound
  components now yield false explicitly; previously the non-exhaustive
  "fn Bound _ => true" raised Match to signal this.*)
fun is_reordering t =
  let val (t', _, _) = HOLogic.strip_psplits t
  in forall (fn Bound _ => true | _ => false) (HOLogic.strip_tuple t') end
bulwahn@50025
   274
bulwahn@48049
   275
(*Translate "{x. EX v1 ... vn. C1 & ... & Ck}" into (fm, f ` S).
  - A conjunct "x = rhs" is chosen as the image function (default x = x).
  - The remaining conjuncts, plus "vi : UNIV" for each vi they do not
    mention, become the formula fm.
  - S is built from fm with Sigma/sup/inf.
  The trivial case (f only rearranges bounds and fm is a single Collect
  atom) is rejected with error; rewrite_term turns every failure into NONE.*)
fun mk_pointfree_expr t =
  let
    val ((x, T), (vs, body)) = apsnd strip_ex (dest_Collect t)
    val Ts = map snd (rev vs)
    val n = length vs
    val lhs = Bound n
    val conjs = HOLogic.dest_conj body
    fun defines_lhs c =
      (case try HOLogic.dest_eq c of
        SOME (l, _) => l = lhs
      | NONE => false)
    val eq =
      (case find_first defines_lhs conjs of
        SOME c => c
      | NONE => HOLogic.eq_const T $ lhs $ lhs)
    val rhs = snd (HOLogic.dest_eq eq)
    val others = filter_out (fn c => eq = c) conjs
    val used = distinct (op =) (maps loose_bnos others)
    val unused = subtract (op =) used (0 upto (n - 1))
    fun mem_UNIV i = HOLogic.mk_mem (Bound i, HOLogic.mk_UNIV (nth Ts i))
    val (pats, fm) =
      mk_formula ((x, T) :: vs) (foldr1 HOLogic.mk_conj (others @ map mem_UNIV unused))
    (*the factor of an atom's product set at position pat*)
    fun component (pat', s) pat =
      if pat = pat' then s else HOLogic.mk_UNIV (type_of_pattern Ts pat)
    fun set_of (Atom a) = foldr1 mk_sigma (map (component a) pats)
      | set_of (Un (fm1, fm2)) = mk_sup (set_of fm1, set_of fm2)
      | set_of (Int (fm1, fm2)) = mk_inf (set_of fm1, set_of fm2)
    val tuple = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
    val f = mk_split_abs (rev ((x, T) :: vs)) tuple (reorder_bounds pats rhs)
    val trivial = the_default false (try is_reordering f) andalso is_collect_atom fm
  in
    if trivial then error "mk_pointfree_expr: trivial case"
    else (fm, mk_image f (set_of fm))
  end;

(*SOME (fm, pointfree term), or NONE if the translation fails for any reason*)
fun rewrite_term t = try mk_pointfree_expr t
rafal@48108
   304
bulwahn@49849
   305
bulwahn@48049
   306
(* proof tactic *)
bulwahn@48049
   307
blanchet@55414
   308
(*an application to z distributes into both components of a case_prod*)
val case_prod_distrib = @{lemma "(case_prod g x) z = case_prod (% x y. (g x y) z) x" by (simp add: case_prod_beta)}

(*negated counterparts of vimageI2/vimageE and CollectI/CollectE*)
val vimageI2' = @{lemma "f a \<notin> A ==> a \<notin> f -` A" by simp}
val vimageE' =
  @{lemma "a \<notin> f -` B ==> (\<And> x. f a = x ==> x \<notin> B ==> P) ==> P" by simp}

val collectI' = @{lemma "\<not> P a ==> a \<notin> {x. P x}" by auto}
val collectE' = @{lemma "a \<notin> {x. P x} ==> (\<not> P a ==> Q) ==> Q" by auto}

(*Turn a comprehension membership premise into its conjuncts: unfold
  mem_Collect_eq, eliminate the existentials and conjunctions, then try to
  substitute equations.*)
fun elim_Collect_tac ctxt =
  EVERY'
   [dresolve_tac ctxt @{thms iffD1 [OF mem_Collect_eq]},
    REPEAT_DETERM o eresolve_tac ctxt @{thms exE},
    REPEAT_DETERM o eresolve_tac ctxt @{thms conjE},
    TRY o hyp_subst_tac ctxt];

(*Introduce an image membership with image_eqI, then repeatedly close the
  equation goal by refl, by congruence through prod.case, or by rewriting
  with case_prod_distrib on the right-hand side.*)
fun intro_image_tac ctxt =
  let
    val distrib_conv =
      Conv.params_conv ~1 (K (Conv.concl_conv ~1
        (HOLogic.Trueprop_conv
          (HOLogic.eq_conv Conv.all_conv (Conv.rewr_conv (mk_meta_eq case_prod_distrib)))))) ctxt
    val eq_step =
      FIRST'
       [resolve_tac ctxt @{thms refl},
        resolve_tac ctxt @{thms arg_cong2 [OF refl, where f = "op =", OF prod.case, THEN iffD2]},
        CONVERSION distrib_conv]
  in
    resolve_tac ctxt @{thms image_eqI} THEN' REPEAT_DETERM1 o eq_step
  end

(*Eliminate an image membership premise with imageE, then normalize while
  something changes: simplify pairs, split pair equations, substitute.*)
fun elim_image_tac ctxt =
  let
    val simp_pairs =
      TRY o full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms split_paired_all prod.case})
    val cleanup =
      EVERY'
       [simp_pairs,
        REPEAT_DETERM o eresolve_tac ctxt @{thms Pair_inject},
        TRY o hyp_subst_tac ctxt]
  in
    eresolve_tac ctxt @{thms imageE} THEN' REPEAT_DETERM o CHANGED o cleanup
  end
bulwahn@48049
   338
wenzelm@51717
   339
(*Forward direction: prove membership in the set built from a formula.
  - Int: split premise conjunctions, apply IntI.
  - Un: case split with disjE, then UnI1/UnI2.
  - Atom: repeatedly apply the introduction rules of the set constructors.*)
fun tac1_of_formula ctxt (Int (fm1, fm2)) =
      EVERY'
       [TRY o eresolve_tac ctxt @{thms conjE},
        resolve_tac ctxt @{thms IntI},
        (*the goal for the second component sits one position further*)
        (fn i => tac1_of_formula ctxt fm2 (i + 1)),
        tac1_of_formula ctxt fm1]
  | tac1_of_formula ctxt (Un (fm1, fm2)) =
      EVERY'
       [eresolve_tac ctxt @{thms disjE},
        resolve_tac ctxt @{thms UnI1},
        tac1_of_formula ctxt fm1,
        resolve_tac ctxt @{thms UnI2},
        tac1_of_formula ctxt fm2]
  | tac1_of_formula ctxt (Atom _) =
      let
        val simp_case = TRY o simp_tac (put_simpset HOL_basic_ss ctxt addsimps [@{thm prod.case}])
        val solve_image_eq =
          REPEAT_DETERM o FIRST'
           [resolve_tac ctxt @{thms refl},
            resolve_tac ctxt @{thms arg_cong2[OF refl, where f = "op =", OF prod.case, THEN iffD2]}]
      in
        REPEAT_DETERM1 o FIRST'
         [assume_tac ctxt,
          resolve_tac ctxt @{thms SigmaI},
          (resolve_tac ctxt @{thms CollectI} ORELSE' resolve_tac ctxt [collectI']) THEN' simp_case,
          (resolve_tac ctxt @{thms vimageI2} ORELSE' resolve_tac ctxt [vimageI2']) THEN' simp_case,
          resolve_tac ctxt @{thms image_eqI} THEN' solve_image_eq,
          resolve_tac ctxt @{thms UNIV_I},
          resolve_tac ctxt @{thms iffD2[OF Compl_iff]},
          assume_tac ctxt]
      end
bulwahn@49849
   363
wenzelm@51717
   364
(*Backward direction: from membership in the set built from a formula,
  prove the corresponding HOL formula.
  - Int: eliminate with IntE, apply conjI.
  - Un: eliminate with UnE, then disjI1/disjI2.
  - Atom: repeatedly apply the elimination rules of the set constructors,
    each followed by a local cleanup.*)
fun tac2_of_formula ctxt (Int (fm1, fm2)) =
      EVERY'
       [TRY o eresolve_tac ctxt @{thms IntE},
        TRY o resolve_tac ctxt @{thms conjI},
        (*the goal for the second component sits one position further*)
        (fn i => tac2_of_formula ctxt fm2 (i + 1)),
        tac2_of_formula ctxt fm1]
  | tac2_of_formula ctxt (Un (fm1, fm2)) =
      EVERY'
       [eresolve_tac ctxt @{thms UnE},
        resolve_tac ctxt @{thms disjI1},
        tac2_of_formula ctxt fm1,
        resolve_tac ctxt @{thms disjI2},
        tac2_of_formula ctxt fm2]
  | tac2_of_formula ctxt (Atom _) =
      let
        val basic_ctxt = put_simpset HOL_basic_ss ctxt
        val simp_case = TRY o full_simp_tac (basic_ctxt addsimps [@{thm prod.case}])
        val simp_pairs = TRY o full_simp_tac (basic_ctxt addsimps @{thms split_paired_all prod.case})
        val split_pairs = REPEAT_DETERM o eresolve_tac ctxt @{thms Pair_inject}
        val subst = TRY o hyp_subst_tac ctxt
        val try_refl = TRY o resolve_tac ctxt @{thms refl}
      in
        REPEAT_DETERM o FIRST'
         [assume_tac ctxt,
          dresolve_tac ctxt @{thms iffD1[OF mem_Sigma_iff]},
          eresolve_tac ctxt @{thms conjE},
          EVERY'
           [eresolve_tac ctxt @{thms CollectE} ORELSE' eresolve_tac ctxt [collectE'],
            simp_case, split_pairs, subst, try_refl],
          eresolve_tac ctxt @{thms imageE}
            THEN' REPEAT_DETERM o CHANGED o EVERY' [simp_pairs, split_pairs, subst, try_refl],
          eresolve_tac ctxt @{thms ComplE},
          EVERY'
           [eresolve_tac ctxt @{thms vimageE} ORELSE' eresolve_tac ctxt [vimageE'],
            simp_case, subst, try_refl]]
      end
bulwahn@49849
   392
wenzelm@51717
   393
(*Prove "comprehension = f ` S" by subset_antisym.
  - Forward inclusion: elim_Collect_tac, intro_image_tac, tac1_of_formula.
  - Backward inclusion: elim_image_tac, reintroduce the comprehension and
    its existentials/conjunctions, then solve each Int-component of fm with
    tac2_of_formula, last component first.*)
fun tac ctxt fm =
  let
    val forward =
      EVERY'
       [resolve_tac ctxt @{thms subsetI},
        elim_Collect_tac ctxt,
        intro_image_tac ctxt,
        tac1_of_formula ctxt fm]
    fun solve_components i =
      strip_Int fm
      |> map_index (fn (j, component) =>
           REPEAT_DETERM (eresolve_tac ctxt @{thms IntE} (i + j)) THEN
           tac2_of_formula ctxt component (i + j))
      |> rev
      |> EVERY
    val backward =
      EVERY'
       [resolve_tac ctxt @{thms subsetI},
        elim_image_tac ctxt,
        resolve_tac ctxt @{thms iffD2[OF mem_Collect_eq]},
        REPEAT_DETERM o resolve_tac ctxt @{thms exI},
        TRY o REPEAT_ALL_NEW (resolve_tac ctxt @{thms conjI}),
        K (TRY (FIRSTGOAL ((TRY o hyp_subst_tac ctxt) THEN' resolve_tac ctxt @{thms refl}))),
        solve_components]
  in
    EVERY' [resolve_tac ctxt @{thms subset_antisym}, forward, backward]
  end;
rafal@48108
   411
bulwahn@49849
   412
bulwahn@49896
   413
(* preprocessing conversion:
bulwahn@49896
   414
  rewrites {(x1, ..., xn). P x1 ... xn} to {(x1, ..., xn) | x1 ... xn. P x1 ... xn} *)
bulwahn@49896
   415
wenzelm@51717
   416
(*Conversion.  If ct is a Collect whose predicate strip_psplits decomposes
  into more than one binder, prove it equal to the existential form
  {x. EX vs. x = (tuple of vs) & P}.  Otherwise (including any failure of
  the construction) return the reflexive theorem.*)
fun comprehension_conv ctxt ct =
  let
    (*local variant: returns the element type and the Collect argument*)
    fun dest_Collect (Const (@{const_name Collect}, T) $ t) = (HOLogic.dest_setT (body_type T), t)
      | dest_Collect t = raise TERM ("dest_Collect", [t])
    fun list_ex vs body =
      fold_rev (fn (x, T) => fn acc => HOLogic.exists_const T $ Abs (x, T, acc)) vs body
    fun mk_term t =
      let
        val (T, pred) = dest_Collect t
        val (body, vs, paths) = strip_psplits pred
        (*a single binder means there is no tuple to expose*)
        val _ = (case vs of [_] => raise TERM ("mk_term", [pred]) | _ => ())
        val Ts = map snd vs
        val n = length Ts
        val tuple =
          HOLogic.mk_ptuple paths (HOLogic.mk_ptupleT paths Ts) (map Bound ((n - 1) downto 0))
        val eq = HOLogic.eq_const T $ Bound n $ tuple
      in
        HOLogic.Collect_const T $ absdummy T (list_ex vs (HOLogic.mk_conj (eq, body)))
      end;
    fun is_eq th = is_some (try (HOLogic.dest_eq o HOLogic.dest_Trueprop) (Thm.prop_of th))
    val unfold_thms = @{thms split_paired_all mem_Collect_eq prod.case}
    fun prove_tac goal_ctxt =
      EVERY'
       [resolve_tac goal_ctxt @{thms set_eqI},
        simp_tac (put_simpset HOL_basic_ss goal_ctxt addsimps unfold_thms),
        resolve_tac goal_ctxt @{thms iffI},
        REPEAT_DETERM o resolve_tac goal_ctxt @{thms exI},
        resolve_tac goal_ctxt @{thms conjI},
        resolve_tac goal_ctxt @{thms refl},
        assume_tac goal_ctxt,
        REPEAT_DETERM o eresolve_tac goal_ctxt @{thms exE},
        eresolve_tac goal_ctxt @{thms conjE},
        REPEAT_DETERM o eresolve_tac goal_ctxt @{thms Pair_inject},
        Subgoal.FOCUS (fn {prems, context = focus_ctxt, ...} =>
          simp_tac (put_simpset HOL_basic_ss focus_ctxt addsimps (filter is_eq prems)) 1) goal_ctxt,
        TRY o assume_tac goal_ctxt]
  in
    (case try mk_term (Thm.term_of ct) of
      NONE => Thm.reflexive ct
    | SOME t' =>
        Goal.prove ctxt [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq (Thm.term_of ct, t')))
          (fn {context, ...} => prove_tac context 1)
        RS @{thm eq_reflection})
  end
bulwahn@49896
   455
bulwahn@49896
   456
bulwahn@49849
   457
(* main simprocs *)
bulwahn@49849
   458
bulwahn@49942
   459
(*rewrites applied before the pointfree translation*)
val prep_thms =
  map mk_meta_eq (@{thm Bex_def} :: @{thm Pow_iff[symmetric]} :: @{thms ex_simps[symmetric]})

(*rewrites applied to the resulting pointfree set expression*)
val post_thms =
  map mk_meta_eq
   [@{thm Times_Un_distrib1[symmetric]},
    @{lemma "A \<times> B \<union> A \<times> C = A \<times> (B \<union> C)" by auto},
    @{lemma "(A \<times> B \<inter> C \<times> D) = (A \<inter> C) \<times> (B \<inter> D)" by auto}]

(*Prove "t = t'", where t' is the pointfree form of the set comprehension t,
  or return NONE when rewrite_term finds no translation.
  - t is imported into a local context with Variable.import_terms.
  - It is preprocessed with comprehension_conv and prep_thms.
  - The resulting equation is post-processed with post_thms and exported
    back to ctxt.*)
fun conv ctxt t =
  let
    val ([t_local], ctxt') = Variable.import_terms true [t] (Variable.declare_term t ctxt)
    fun rewr_conv thms =
      Raw_Simplifier.rewrite_cterm (false, false, false) (K (K NONE))
        (empty_simpset ctxt' addsimps thms)
    val prep_eq =
      (comprehension_conv ctxt' then_conv rewr_conv prep_thms) (Thm.cterm_of ctxt' t_local)
    val prepped = Thm.term_of (Thm.rhs_of prep_eq)
    fun prove (fm, target) =
      Goal.prove ctxt' [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq (prepped, target)))
        (fn {context, ...} => tac context fm 1)
    (*chain the preprocessing equation in front of the proved equation*)
    fun chain th = th RS ((prep_eq RS meta_eq_to_obj_eq) RS @{thm trans})
    val post =
      Conv.fconv_rule
        (HOLogic.Trueprop_conv (HOLogic.eq_conv Conv.all_conv (rewr_conv post_thms)))
    val export = singleton (Variable.export ctxt' ctxt)
  in
    (case rewrite_term prepped of
      NONE => NONE
    | SOME result => SOME (export (post (chain (prove result)))))
  end;
bulwahn@48049
   486
wenzelm@51717
   487
(*rewrite the redex (a set comprehension) into its pointfree form*)
fun base_simproc ctxt redex =
  Option.map (fn th => th RS @{thm eq_reflection}) (conv ctxt (Thm.term_of redex));

(*arg_cong with its function variable instantiated to pred.  The rule's
  indexes are first shifted above those of pred to avoid clashes.*)
fun instantiate_arg_cong ctxt pred =
  let
    val rule = Thm.incr_indexes (maxidx_of_term pred + 1) @{thm arg_cong}
    val (lhs, _) = HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.concl_of rule))
    val (Var (f, _) $ _) = lhs
  in
    infer_instantiate ctxt [(f, Thm.cterm_of ctxt pred)] rule
  end;

(*For a redex "pred set_compr", rewrite set_compr and lift the equation
  through pred with arg_cong.*)
fun simproc ctxt redex =
  let
    val pred $ set_compr = Thm.term_of redex
    val cong = instantiate_arg_cong ctxt pred
    fun finish th = (th RS cong) RS @{thm eq_reflection}
  in
    (case conv ctxt set_compr of
      SOME th => SOME (finish th)
    | NONE => NONE)
  end;

(*base_simproc wrapped between rewriting with eq_equal[symmetric] before and
  with eq_equal after; the three equations are chained by transitivity.*)
fun code_simproc ctxt redex =
  let
    fun rewr_conv thms =
      Raw_Simplifier.rewrite_cterm (false, false, false) (K (K NONE))
        (empty_simpset ctxt addsimps thms)
    val to_eq = rewr_conv @{thms eq_equal[symmetric]} redex
    fun finish rewr_thm =
      let
        val via_eq = transitive_thm OF [to_eq, rewr_thm]
        val back = rewr_conv @{thms eq_equal} (Thm.rhs_of rewr_thm)
      in transitive_thm OF [via_eq, back] end
  in
    Option.map finish (base_simproc ctxt (Thm.rhs_of to_eq))
  end;
bulwahn@48122
   524
bulwahn@48049
   525
end;