src/HOL/Tools/set_comprehension_pointfree.ML
author wenzelm
Wed Jul 08 21:33:00 2015 +0200 (2015-07-08)
changeset 60696 8304fb4fb823
parent 59621 291934bac95e
child 60781 2da59cdf531c
permissions -rw-r--r--
clarified context;
     1 (*  Title:      HOL/Tools/set_comprehension_pointfree.ML
     2     Author:     Felix Kuperjans, Lukas Bulwahn, TU Muenchen
     3     Author:     Rafal Kolanski, NICTA
     4 
     5 Simproc for rewriting set comprehensions to pointfree expressions.
     6 *)
     7 
     8 signature SET_COMPREHENSION_POINTFREE =
     9 sig
    10   val base_simproc : Proof.context -> cterm -> thm option
    11   val code_simproc : Proof.context -> cterm -> thm option
    12   val simproc : Proof.context -> cterm -> thm option
    13 end
    14 
    15 structure Set_Comprehension_Pointfree : SET_COMPREHENSION_POINTFREE =
    16 struct
    17 
    18 (* syntactic operations *)
    19 
    20 fun mk_inf (t1, t2) =
    21   let
    22     val T = fastype_of t1
    23   in
    24     Const (@{const_name Lattices.inf_class.inf}, T --> T --> T) $ t1 $ t2
    25   end
    26 
    27 fun mk_sup (t1, t2) =
    28   let
    29     val T = fastype_of t1
    30   in
    31     Const (@{const_name Lattices.sup_class.sup}, T --> T --> T) $ t1 $ t2
    32   end
    33 
    34 fun mk_Compl t =
    35   let
    36     val T = fastype_of t
    37   in
    38     Const (@{const_name "Groups.uminus_class.uminus"}, T --> T) $ t
    39   end
    40 
    41 fun mk_image t1 t2 =
    42   let
    43     val T as Type (@{type_name fun}, [_ , R]) = fastype_of t1
    44   in
    45     Const (@{const_name image},
    46       T --> fastype_of t2 --> HOLogic.mk_setT R) $ t1 $ t2
    47   end;
    48 
    49 fun mk_sigma (t1, t2) =
    50   let
    51     val T1 = fastype_of t1
    52     val T2 = fastype_of t2
    53     val setT = HOLogic.dest_setT T1
    54     val resT = HOLogic.mk_setT (HOLogic.mk_prodT (setT, HOLogic.dest_setT T2))
    55   in
    56     Const (@{const_name Sigma},
    57       T1 --> (setT --> T2) --> resT) $ t1 $ absdummy setT t2
    58   end;
    59 
    60 fun mk_vimage f s =
    61   let
    62     val T as Type (@{type_name fun}, [T1, T2]) = fastype_of f
    63   in
    64     Const (@{const_name vimage}, T --> HOLogic.mk_setT T2 --> HOLogic.mk_setT T1) $ f $ s
    65   end; 
    66 
    67 fun dest_Collect (Const (@{const_name Collect}, _) $ Abs (x, T, t)) = ((x, T), t)
    68   | dest_Collect t = raise TERM ("dest_Collect", [t])
    69 
    70 (* Copied from predicate_compile_aux.ML *)
    71 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
    72   let
    73     val (xTs, t') = strip_ex t
    74   in
    75     ((x, T) :: xTs, t')
    76   end
    77   | strip_ex t = ([], t)
    78 
    79 fun mk_prod1 Ts (t1, t2) =
    80   let
    81     val (T1, T2) = apply2 (curry fastype_of1 Ts) (t1, t2)
    82   in
    83     HOLogic.pair_const T1 T2 $ t1 $ t2
    84   end;
    85 
    86 fun mk_split_abs vs (Bound i) t = let val (x, T) = nth vs i in Abs (x, T, t) end
    87   | mk_split_abs vs (Const (@{const_name Product_Type.Pair}, _) $ u $ v) t =
    88       HOLogic.mk_split (mk_split_abs vs u (mk_split_abs vs v t))
    89   | mk_split_abs _ t _ = raise TERM ("mk_split_abs: bad term", [t]);
    90 
    91 (* a variant of HOLogic.strip_psplits *)
(*strip_psplits t: strips the case_prod/abstraction prefix of t.  Returns the body, the stripped
  variables (outermost first) and the list of positions at which a case_prod was found.  Where
  neither a case_prod nor an abstraction occurs at a position, t is eta-expanded with a dummy
  variable named Name.uu_.*)
val strip_psplits =
  let
    (*ps: worklist of positions still to be stripped; qs: case_prod positions found so far;
      vs: stripped variables, innermost first (reversed on return)*)
    fun strip [] qs vs t = (t, rev vs, qs)
        (*case_prod at position p: schedule both component positions and record p*)
      | strip (p :: ps) qs vs (Const (@{const_name Product_Type.prod.case_prod}, _) $ t) =
          strip ((1 :: p) :: (2 :: p) :: ps) (p :: qs) vs t
      | strip (_ :: ps) qs vs (Abs (s, T, t)) = strip ps qs ((s, T) :: vs) t
        (*no abstraction here: eta-expand, typing the new variable from t's first binder type*)
      | strip (_ :: ps) qs vs t = strip ps qs
          ((Name.uu_, hd (binder_types (fastype_of1 (map snd vs, t)))) :: vs)
          (incr_boundvars 1 t $ Bound 0)
  in strip [[]] [] [] end;
   102 
   103 (* patterns *)
   104 
   105 datatype pattern = Pattern of int list
   106 
   107 fun dest_Pattern (Pattern bs) = bs
   108 
   109 fun dest_bound (Bound i) = i
   110   | dest_bound t = raise TERM("dest_bound", [t]);
   111 
   112 fun type_of_pattern Ts (Pattern bs) = HOLogic.mk_tupleT (map (nth Ts) bs)
   113 
   114 fun term_of_pattern Ts (Pattern bs) =
   115     let
   116       fun mk [b] = Bound b
   117         | mk (b :: bs) = HOLogic.pair_const (nth Ts b) (type_of_pattern Ts (Pattern bs))
   118            $ Bound b $ mk bs
   119     in mk bs end;
   120 
   121 (* formulas *)
   122 
   123 datatype formula = Atom of (pattern * term) | Int of formula * formula | Un of formula * formula
   124 
   125 fun map_atom f (Atom a) = Atom (f a)
   126   | map_atom _ x = x
   127 
   128 fun is_collect_atom (Atom (_, Const(@{const_name Collect}, _) $ _)) = true
   129   | is_collect_atom (Atom (_, Const (@{const_name "Groups.uminus_class.uminus"}, _) $ (Const(@{const_name Collect}, _) $ _))) = true
   130   | is_collect_atom _ = false
   131 
   132 fun mk_split _ [(x, T)] t = (T, Abs (x, T, t))
   133   | mk_split rT ((x, T) :: vs) t =
   134     let
   135       val (T', t') = mk_split rT vs t
   136       val t'' = HOLogic.split_const (T, T', rT) $ (Abs (x, T, t'))
   137     in (domain_type (fastype_of t''), t'') end
   138 
(*mk_term vs t: collects the loose bound variables of t and renumbers them compactly.  Returns
  the pattern of the original indices (in loose_bnos order) together with their (name, type)
  entries and the renumbered term.  vs is indexed such that  nth (rev vs) i  describes Bound i.*)
fun mk_term vs t =
  let
    val bs = loose_bnos t
    val vs' = map (nth (rev vs)) bs
    (*old index j becomes Bound i, i being j's position in rev bs*)
    val subst = map_index (fn (i, j) => (j, Bound i)) (rev bs)
      |> sort (fn (p1, p2) => int_ord (fst p1, fst p2))
      (*expand to a full substitution list for indices 0 .. (largest j), identity on the gaps;
        NOTE(review): split_last of an empty list fails, so a closed t raises an exception here*)
      |> (fn subst' => map (fn i => the_default (Bound i) (AList.lookup (op =) subst' i)) (0 upto (fst (snd (split_last subst')))))
    val t' = subst_bounds (subst, t)
    val tuple = Pattern bs
  in (tuple, (vs', t')) end
   149 
   150 fun default_atom vs t =
   151   let
   152     val (tuple, (vs', t')) = mk_term vs t
   153     val T = HOLogic.mk_tupleT (map snd vs')
   154     val s = HOLogic.Collect_const T $ (snd (mk_split @{typ bool} vs' t'))
   155   in
   156     (tuple, Atom (tuple, s))
   157   end
   158 
   159 fun mk_atom vs (t as Const (@{const_name "Set.member"}, _) $ x $ s) =
   160     if not (null (loose_bnos s)) then
   161       default_atom vs t
   162     else
   163       (case try ((map dest_bound) o HOLogic.strip_tuple) x of
   164       SOME pat => (Pattern pat, Atom (Pattern pat, s))
   165     | NONE =>
   166         let
   167           val (tuple, (vs', x')) = mk_term vs x 
   168           val rT = HOLogic.dest_setT (fastype_of s)
   169           val s = mk_vimage (snd (mk_split rT vs' x')) s
   170         in (tuple, Atom (tuple, s)) end)
   171   | mk_atom vs (Const (@{const_name "HOL.Not"}, _) $ t) = apsnd (map_atom (apsnd mk_Compl)) (mk_atom vs t)
   172   | mk_atom vs t = default_atom vs t
   173 
   174 fun merge' [] (pats1, pats2) = ([], (pats1, pats2))
   175   | merge' pat (pats, []) = (pat, (pats, []))
   176   | merge' pat (pats1, pats) =
   177   let
   178     fun disjoint_to_pat p = null (inter (op =) pat p)
   179     val overlap_pats = filter_out disjoint_to_pat pats
   180     val rem_pats = filter disjoint_to_pat pats
   181     val (pat, (pats', pats1')) = merge' (distinct (op =) (flat overlap_pats @ pat)) (rem_pats, pats1)
   182   in
   183     (pat, (pats1', pats'))
   184   end
   185 
   186 fun merge ([], pats) = pats
   187   | merge (pat :: pats', pats) =
   188   let val (pat', (pats1', pats2')) = merge' pat (pats', pats)
   189   in pat' :: merge (pats1', pats2') end;
   190 
   191 fun restricted_merge ([], pats) = pats
   192   | restricted_merge (pat :: pats', pats) =
   193   let
   194     fun disjoint_to_pat p = null (inter (op =) pat p)
   195     val overlap_pats = filter_out disjoint_to_pat pats
   196     val rem_pats = filter disjoint_to_pat pats
   197   in
   198     case overlap_pats of
   199       [] => pat :: restricted_merge (pats', rem_pats)
   200     | [pat'] => if subset (op =) (pat, pat') then
   201         pat' :: restricted_merge (pats', rem_pats)
   202       else if subset (op =) (pat', pat) then
   203         pat :: restricted_merge (pats', rem_pats)
   204       else error "restricted merge: two patterns require relational join"
   205     | _ => error "restricted merge: multiple patterns overlap"
   206   end;
   207   
   208 fun map_atoms f (Atom a) = Atom (f a)
   209   | map_atoms f (Un (fm1, fm2)) = Un (apply2 (map_atoms f) (fm1, fm2))
   210   | map_atoms f (Int (fm1, fm2)) = Int (apply2 (map_atoms f) (fm1, fm2))
   211 
   212 fun extend Ts bs t = foldr1 mk_sigma (t :: map (fn b => HOLogic.mk_UNIV (nth Ts b)) bs)
   213 
   214 fun rearrange vs (pat, pat') t =
   215   let
   216     val subst = map_index (fn (i, b) => (b, i)) (rev pat)
   217     val vs' = map (nth (rev vs)) pat
   218     val Ts' = map snd (rev vs')
   219     val bs = map (fn b => the (AList.lookup (op =) subst b)) pat'
   220     val rt = term_of_pattern Ts' (Pattern bs)
   221     val rT = type_of_pattern Ts' (Pattern bs)
   222     val (_, f) = mk_split rT vs' rt
   223   in
   224     mk_image f t
   225   end;
   226 
   227 fun adjust vs pats (Pattern pat, t) =
   228   let
   229     val SOME p = find_first (fn p => not (null (inter (op =) pat p))) pats
   230     val missing = subtract (op =) pat p
   231     val Ts = rev (map snd vs)
   232     val t' = extend Ts missing t
   233   in (Pattern p, rearrange vs (pat @ missing, p) t') end
   234 
   235 fun adjust_atoms vs pats fm = map_atoms (adjust vs pats) fm
   236 
   237 fun merge_inter vs (pats1, fm1) (pats2, fm2) =
   238   let
   239     val pats = restricted_merge (map dest_Pattern pats1, map dest_Pattern pats2) 
   240     val (fm1', fm2') = apply2 (adjust_atoms vs pats) (fm1, fm2)
   241   in
   242     (map Pattern pats, Int (fm1', fm2'))
   243   end;
   244 
   245 fun merge_union vs (pats1, fm1) (pats2, fm2) = 
   246   let
   247     val pats = merge (map dest_Pattern pats1, map dest_Pattern pats2)
   248     val (fm1', fm2') = apply2 (adjust_atoms vs pats) (fm1, fm2)
   249   in
   250     (map Pattern pats, Un (fm1', fm2'))
   251   end;
   252 
   253 fun mk_formula vs (@{const HOL.conj} $ t1 $ t2) = merge_inter vs (mk_formula vs t1) (mk_formula vs t2)
   254   | mk_formula vs (@{const HOL.disj} $ t1 $ t2) = merge_union vs (mk_formula vs t1) (mk_formula vs t2)
   255   | mk_formula vs t = apfst single (mk_atom vs t)
   256 
   257 fun strip_Int (Int (fm1, fm2)) = fm1 :: (strip_Int fm2) 
   258   | strip_Int fm = [fm]
   259 
   260 (* term construction *)
   261 
   262 fun reorder_bounds pats t =
   263   let
   264     val bounds = maps dest_Pattern pats
   265     val bperm = bounds ~~ ((length bounds - 1) downto 0)
   266       |> sort (fn (i,j) => int_ord (fst i, fst j)) |> map snd
   267   in
   268     subst_bounds (map Bound bperm, t)
   269   end;
   270 
   271 fun is_reordering t =
   272   let val (t', _, _) = HOLogic.strip_psplits t
   273   in forall (fn Bound _ => true) (HOLogic.strip_tuple t') end
   274 
(*mk_pointfree_expr t: for t = {x. EX v1 ... vn. C1 & ... & Ck}, builds a formula over tuple
  patterns of the vi from the conjuncts.  The conjunct  x = f  (if present) supplies the image
  function.  Returns the formula and the term  g ` S, where S is assembled from the formula with
  Sigma/sup/inf.  Raises an exception if t has the wrong shape or the result would be trivial.*)
fun mk_pointfree_expr t =
  let
    val ((x, T), (vs, t'')) = apsnd strip_ex (dest_Collect t)
    val Ts = map snd (rev vs)
    (*the conjunct  Bound n : UNIV, used to mention variables occurring in no conjunct*)
    fun mk_mem_UNIV n = HOLogic.mk_mem (Bound n, HOLogic.mk_UNIV (nth Ts n))
    (*the set of an atom if it ranges over pattern pat, otherwise UNIV of pat's tuple type*)
    fun lookup (pat', t) pat = if pat = pat' then t else HOLogic.mk_UNIV (type_of_pattern Ts pat)
    val conjs = HOLogic.dest_conj t''
    (*Bound (length vs) is the comprehension variable x inside the existential prefix;
      x = x is the default when no conjunct defines x*)
    val refl = HOLogic.eq_const T $ Bound (length vs) $ Bound (length vs)
    val is_the_eq =
      the_default false o (try (fn eq => fst (HOLogic.dest_eq eq) = Bound (length vs)))
    val eq = the_default refl (find_first is_the_eq conjs)
    val f = snd (HOLogic.dest_eq eq)
    val conjs' = filter_out (fn t => eq = t) conjs
    val unused_bounds = subtract (op =) (distinct (op =) (maps loose_bnos conjs'))
      (0 upto (length vs - 1))
    val (pats, fm) =
      mk_formula ((x, T) :: vs) (foldr1 HOLogic.mk_conj (conjs' @ map mk_mem_UNIV unused_bounds))
    (*the set denoted by a formula: atoms become products over all patterns*)
    fun mk_set (Atom pt) = foldr1 mk_sigma (map (lookup pt) pats)
      | mk_set (Un (f1, f2)) = mk_sup (mk_set f1, mk_set f2)
      | mk_set (Int (f1, f2)) = mk_inf (mk_set f1, mk_set f2)
    (*the image function: abstracts f over the tuple of all patterns*)
    val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
    val t = mk_split_abs (rev ((x, T) :: vs)) pat (reorder_bounds pats f)
  in
    (*give up when the function only reorders tuple components and the formula is a single
      comprehension atom (or its complement) -- presumably nothing would be gained*)
    if the_default false (try is_reordering t) andalso is_collect_atom fm then
      error "mk_pointfree_expr: trivial case" 
    else (fm, mk_image t (mk_set fm))
  end;
   302 
   303 val rewrite_term = try mk_pointfree_expr
   304 
   305 
   306 (* proof tactic *)
   307 
(*pushes an application to a further argument z inside case_prod*)
val case_prod_distrib = @{lemma "(case_prod g x) z = case_prod (% x y. (g x y) z) x" by (simp add: case_prod_beta)}

(*negated counterparts of vimageI2 / vimageE, used for complemented preimage atoms*)
val vimageI2' = @{lemma "f a \<notin> A ==> a \<notin> f -` A" by simp}
val vimageE' =
  @{lemma "a \<notin> f -` B ==> (\<And> x. f a = x ==> x \<notin> B ==> P) ==> P" by simp}

(*negated counterparts of CollectI / CollectE, used for complemented comprehension atoms*)
val collectI' = @{lemma "\<not> P a ==> a \<notin> {x. P x}" by auto}
val collectE' = @{lemma "a \<notin> {x. P x} ==> (\<not> P a ==> Q) ==> Q" by auto}
   316 
(*eliminates an assumption  a : {x. EX v1 ... vn. C1 & ... & Ck}: the existentials become
  parameters, the conjuncts separate assumptions, and equations are substituted if possible*)
fun elim_Collect_tac ctxt =
  dresolve_tac ctxt @{thms iffD1 [OF mem_Collect_eq]}
  THEN' (REPEAT_DETERM o (eresolve_tac ctxt @{thms exE}))
  THEN' REPEAT_DETERM o eresolve_tac ctxt @{thms conjE}
  THEN' TRY o hyp_subst_tac ctxt;
   322 
(*introduces membership in an image via image_eqI.  The resulting equation is closed by
  repeatedly trying refl, reducing case_prod applied to an explicit pair (prod.case), or
  rewriting the right-hand side of the equation with case_prod_distrib under all parameters.*)
fun intro_image_tac ctxt =
  resolve_tac ctxt @{thms image_eqI}
  THEN' (REPEAT_DETERM1 o
      (resolve_tac ctxt @{thms refl}
      ORELSE' resolve_tac ctxt @{thms arg_cong2 [OF refl, where f = "op =", OF prod.case, THEN iffD2]}
      ORELSE' CONVERSION (Conv.params_conv ~1 (K (Conv.concl_conv ~1
        (HOLogic.Trueprop_conv
          (HOLogic.eq_conv Conv.all_conv (Conv.rewr_conv (mk_meta_eq case_prod_distrib)))))) ctxt)))
   331 
(*eliminates an assumption  a : f ` S  via imageE.  It then repeats, while the goal changes:
  splitting paired variables, reducing case_prod, decomposing pair equations and substituting.*)
fun elim_image_tac ctxt =
  eresolve_tac ctxt @{thms imageE}
  THEN' REPEAT_DETERM o CHANGED o
    (TRY o full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms split_paired_all prod.case})
    THEN' REPEAT_DETERM o eresolve_tac ctxt @{thms Pair_inject}
    THEN' TRY o hyp_subst_tac ctxt)
   338 
(*proves membership in the set built from fm, given the comprehension's conjuncts as
  assumptions (the comprehension-to-image subset direction), following fm's structure*)
fun tac1_of_formula ctxt (Int (fm1, fm2)) =
    (*IntI yields goals i and i + 1; handle i + 1 first so that goal i keeps its index*)
    TRY o eresolve_tac ctxt @{thms conjE}
    THEN' resolve_tac ctxt @{thms IntI}
    THEN' (fn i => tac1_of_formula ctxt fm2 (i + 1))
    THEN' tac1_of_formula ctxt fm1
  | tac1_of_formula ctxt (Un (fm1, fm2)) =
    (*case split on the disjunction; expects the fm1 tactic to close the first case, so that
      the second case moves up to position i*)
    eresolve_tac ctxt @{thms disjE} THEN' resolve_tac ctxt @{thms UnI1}
    THEN' tac1_of_formula ctxt fm1
    THEN' resolve_tac ctxt @{thms UnI2}
    THEN' tac1_of_formula ctxt fm2
  | tac1_of_formula ctxt (Atom _) =
    (*repeatedly applies whichever introduction step fits: assumption, Sigma, (negated)
      Collect, (negated) vimage, image, UNIV, Compl*)
    REPEAT_DETERM1 o (assume_tac ctxt
      ORELSE' resolve_tac ctxt @{thms SigmaI}
      ORELSE' ((resolve_tac ctxt @{thms CollectI} ORELSE' resolve_tac ctxt [collectI']) THEN'
        TRY o simp_tac (put_simpset HOL_basic_ss ctxt addsimps [@{thm prod.case}]))
      ORELSE' ((resolve_tac ctxt @{thms vimageI2} ORELSE' resolve_tac ctxt [vimageI2']) THEN'
        TRY o simp_tac (put_simpset HOL_basic_ss ctxt addsimps [@{thm prod.case}]))
      ORELSE' (resolve_tac ctxt @{thms image_eqI} THEN'
    (REPEAT_DETERM o
      (resolve_tac ctxt @{thms refl}
      ORELSE' resolve_tac ctxt @{thms arg_cong2[OF refl, where f = "op =", OF prod.case, THEN iffD2]})))
      ORELSE' resolve_tac ctxt @{thms UNIV_I}
      ORELSE' resolve_tac ctxt @{thms iffD2[OF Compl_iff]}
      ORELSE' assume_tac ctxt)
   363 
(*the converse of tac1_of_formula: from membership in the set built from fm, derives the
  corresponding conjuncts of the comprehension, following fm's structure*)
fun tac2_of_formula ctxt (Int (fm1, fm2)) =
    (*conjI yields goals i and i + 1; handle i + 1 first so that goal i keeps its index*)
    TRY o eresolve_tac ctxt @{thms IntE}
    THEN' TRY o resolve_tac ctxt @{thms conjI}
    THEN' (fn i => tac2_of_formula ctxt fm2 (i + 1))
    THEN' tac2_of_formula ctxt fm1
  | tac2_of_formula ctxt (Un (fm1, fm2)) =
    eresolve_tac ctxt @{thms UnE} THEN' resolve_tac ctxt @{thms disjI1}
    THEN' tac2_of_formula ctxt fm1
    THEN' resolve_tac ctxt @{thms disjI2}
    THEN' tac2_of_formula ctxt fm2
  | tac2_of_formula ctxt (Atom _) =
    (*repeatedly applies whichever elimination step fits: assumption, Sigma, conjunction,
      (negated) Collect, image, Compl, (negated) vimage -- each followed by case_prod
      reduction, pair decomposition and substitution where relevant*)
    REPEAT_DETERM o
      (assume_tac ctxt
       ORELSE' dresolve_tac ctxt @{thms iffD1[OF mem_Sigma_iff]}
       ORELSE' eresolve_tac ctxt @{thms conjE}
       ORELSE' ((eresolve_tac ctxt @{thms CollectE} ORELSE' eresolve_tac ctxt [collectE']) THEN'
         TRY o full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps [@{thm prod.case}]) THEN'
         REPEAT_DETERM o eresolve_tac ctxt @{thms Pair_inject} THEN' TRY o hyp_subst_tac ctxt THEN'
         TRY o resolve_tac ctxt @{thms refl})
       ORELSE' (eresolve_tac ctxt @{thms imageE}
         THEN' (REPEAT_DETERM o CHANGED o
         (TRY o full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps @{thms split_paired_all prod.case})
         THEN' REPEAT_DETERM o eresolve_tac ctxt @{thms Pair_inject}
         THEN' TRY o hyp_subst_tac ctxt THEN' TRY o resolve_tac ctxt @{thms refl})))
       ORELSE' eresolve_tac ctxt @{thms ComplE}
       ORELSE' ((eresolve_tac ctxt @{thms vimageE} ORELSE' eresolve_tac ctxt [vimageE'])
        THEN' TRY o full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps [@{thm prod.case}])
        THEN' TRY o hyp_subst_tac ctxt THEN' TRY o resolve_tac ctxt @{thms refl}))
   392 
(*proves  {comprehension} = f ` S  by subset_antisym: subset_tac1 shows the comprehension is
  contained in the image, subset_tac2 shows the converse*)
fun tac ctxt fm =
  let
    val subset_tac1 = resolve_tac ctxt @{thms subsetI}
      THEN' elim_Collect_tac ctxt
      THEN' intro_image_tac ctxt
      THEN' tac1_of_formula ctxt fm
    val subset_tac2 = resolve_tac ctxt @{thms subsetI}
      THEN' elim_image_tac ctxt
      THEN' resolve_tac ctxt @{thms iffD2[OF mem_Collect_eq]}
      THEN' REPEAT_DETERM o resolve_tac ctxt @{thms exI}
      THEN' (TRY o REPEAT_ALL_NEW (resolve_tac ctxt @{thms conjI}))
      (*close the defining equation of the comprehension variable, if possible*)
      THEN' (K (TRY (FIRSTGOAL ((TRY o hyp_subst_tac ctxt) THEN' resolve_tac ctxt @{thms refl}))))
      (*one goal per top-level Int component of fm, processed last-to-first so that the
         indices i + j of earlier components stay valid*)
      THEN' (fn i => EVERY (rev (map_index (fn (j, f) =>
        REPEAT_DETERM (eresolve_tac ctxt @{thms IntE} (i + j)) THEN
        tac2_of_formula ctxt f (i + j)) (strip_Int fm))))
  in
    resolve_tac ctxt @{thms subset_antisym} THEN' subset_tac1 THEN' subset_tac2
  end;
   411 
   412 
   413 (* preprocessing conversion:
   414   rewrites {(x1, ..., xn). P x1 ... xn} to {(x1, ..., xn) | x1 ... xn. P x1 ... xn} *)
   415 
(*comprehension_conv ctxt ct: if ct has the form {(x1, ..., xn). P x1 ... xn} with at least two
  split variables, returns the proved meta-equation
    ct == {u | x1 ... xn. u = (x1, ..., xn) & P x1 ... xn};
  otherwise (including when mk_term fails) returns the reflexive equation for ct*)
fun comprehension_conv ctxt ct =
  let
    (*local variant: returns the element type and the predicate of a Collect term*)
    fun dest_Collect (Const (@{const_name Collect}, T) $ t) = (HOLogic.dest_setT (body_type T), t)
      | dest_Collect t = raise TERM ("dest_Collect", [t])
    (*existentially quantifies t over vs, the first entry of vs outermost*)
    fun list_ex vs t = fold_rev (fn (x, T) => fn t => HOLogic.exists_const T $ Abs (x, T, t)) vs t
    (*the rewritten comprehension; a single split variable is rejected as there is nothing
      to gain*)
    fun mk_term t =
      let
        val (T, t') = dest_Collect t
        val (t'', vs, fp) = case strip_psplits t' of
            (_, [_], _) => raise TERM("mk_term", [t'])
          | (t'', vs, fp) => (t'', vs, fp)
        val Ts = map snd vs
        (*u = (x1, ..., xn), where u is bound outside the existential prefix*)
        val eq = HOLogic.eq_const T $ Bound (length Ts) $
          (HOLogic.mk_ptuple fp (HOLogic.mk_ptupleT fp Ts) (rev (map_index (fn (i, _) => Bound i) Ts)))
      in
        HOLogic.Collect_const T $ absdummy T (list_ex vs (HOLogic.mk_conj (eq, t'')))
      end;
    fun is_eq th = is_some (try (HOLogic.dest_eq o HOLogic.dest_Trueprop) (Thm.prop_of th))
    val unfold_thms = @{thms split_paired_all mem_Collect_eq prod.case}
    (*local proof tactic for the set equation (shadows the outer tac)*)
    fun tac ctxt = 
      resolve_tac ctxt @{thms set_eqI}
      THEN' simp_tac (put_simpset HOL_basic_ss ctxt addsimps unfold_thms)
      THEN' resolve_tac ctxt @{thms iffI}
      THEN' REPEAT_DETERM o resolve_tac ctxt @{thms exI}
      THEN' resolve_tac ctxt @{thms conjI} THEN' resolve_tac ctxt @{thms refl} THEN' assume_tac ctxt
      THEN' REPEAT_DETERM o eresolve_tac ctxt @{thms exE}
      THEN' eresolve_tac ctxt @{thms conjE}
      THEN' REPEAT_DETERM o eresolve_tac ctxt @{thms Pair_inject}
      (*rewrite the goal with the equational premises only*)
      THEN' Subgoal.FOCUS (fn {prems, context = ctxt', ...} =>
        simp_tac (put_simpset HOL_basic_ss ctxt' addsimps (filter is_eq prems)) 1) ctxt
      THEN' TRY o assume_tac ctxt
  in
    case try mk_term (Thm.term_of ct) of
      NONE => Thm.reflexive ct
    | SOME t' =>
      Goal.prove ctxt [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq (Thm.term_of ct, t')))
          (fn {context, ...} => tac context 1)
        RS @{thm eq_reflection}
  end
   455 
   456 
   457 (* main simprocs *)
   458 
(*preprocessing rewrites: unfold bounded existentials (Bex_def), fold membership in Pow
  (Pow_iff reversed) and ex_simps reversed (presumably pulling existentials outwards)*)
val prep_thms =
  map mk_meta_eq ([@{thm Bex_def}, @{thm Pow_iff[symmetric]}] @ @{thms ex_simps[symmetric]})

(*postprocessing rewrites: factor unions and intersections of products
  (Times_Un_distrib1 reversed, plus the two lemmas below)*)
val post_thms =
  map mk_meta_eq [@{thm Times_Un_distrib1[symmetric]},
  @{lemma "A \<times> B \<union> A \<times> C = A \<times> (B \<union> C)" by auto},
  @{lemma "(A \<times> B \<inter> C \<times> D) = (A \<inter> C) \<times> (B \<inter> D)" by auto}]
   466 
(*conv ctxt t: tries to prove an object-level equation  t = f ` S  for the set comprehension t.
  The steps are:
  1. preprocess with comprehension_conv and prep_thms;
  2. translate with rewrite_term and prove the equation with tac;
  3. chain it with the preprocessing equation and simplify the right-hand side with post_thms;
  4. export the result back to ctxt.
  Returns NONE if rewrite_term finds no translation.*)
fun conv ctxt t =
  let
    (*work on a copy of t whose schematic variables are fixed in ctxt'*)
    val ([t'], ctxt') = Variable.import_terms true [t] (Variable.declare_term t ctxt)
    val ct = Thm.cterm_of ctxt' t'
    fun unfold_conv thms =
      Raw_Simplifier.rewrite_cterm (false, false, false) (K (K NONE))
        (empty_simpset ctxt' addsimps thms)
    val prep_eq = (comprehension_conv ctxt' then_conv unfold_conv prep_thms) ct
    val t'' = Thm.term_of (Thm.rhs_of prep_eq)
    fun mk_thm (fm, t''') = Goal.prove ctxt' [] []
      (HOLogic.mk_Trueprop (HOLogic.mk_eq (t'', t'''))) (fn {context, ...} => tac context fm 1)
    (*from  t'' = t'''  and  t' == t''  derive  t' = t''' by transitivity*)
    fun unfold th = th RS ((prep_eq RS meta_eq_to_obj_eq) RS @{thm trans})
    (*simplify only the right-hand side of the equation*)
    val post =
      Conv.fconv_rule
        (HOLogic.Trueprop_conv (HOLogic.eq_conv Conv.all_conv (unfold_conv post_thms)))
    val export = singleton (Variable.export ctxt' ctxt)
  in
    Option.map (export o post o unfold o mk_thm) (rewrite_term t'')
  end;
   486 
   487 fun base_simproc ctxt redex =
   488   let
   489     val set_compr = Thm.term_of redex
   490   in
   491     conv ctxt set_compr
   492     |> Option.map (fn thm => thm RS @{thm eq_reflection})
   493   end;
   494 
   495 fun instantiate_arg_cong ctxt pred =
   496   let
   497     val arg_cong = Thm.incr_indexes (maxidx_of_term pred + 1) @{thm arg_cong}
   498     val f $ _ = fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.concl_of arg_cong)))
   499   in
   500     cterm_instantiate [(Thm.cterm_of ctxt f, Thm.cterm_of ctxt pred)] arg_cong
   501   end;
   502 
   503 fun simproc ctxt redex =
   504   let
   505     val pred $ set_compr = Thm.term_of redex
   506     val arg_cong' = instantiate_arg_cong ctxt pred
   507   in
   508     conv ctxt set_compr
   509     |> Option.map (fn thm => thm RS arg_cong' RS @{thm eq_reflection})
   510   end;
   511 
   512 fun code_simproc ctxt redex =
   513   let
   514     fun unfold_conv thms =
   515       Raw_Simplifier.rewrite_cterm (false, false, false) (K (K NONE))
   516         (empty_simpset ctxt addsimps thms)
   517     val prep_thm = unfold_conv @{thms eq_equal[symmetric]} redex
   518   in
   519     case base_simproc ctxt (Thm.rhs_of prep_thm) of
   520       SOME rewr_thm => SOME (transitive_thm OF [transitive_thm OF [prep_thm, rewr_thm],
   521         unfold_conv @{thms eq_equal} (Thm.rhs_of rewr_thm)])
   522     | NONE => NONE
   523   end;
   524 
   525 end;