--- a/src/HOL/Tools/set_comprehension_pointfree.ML Thu Nov 08 11:59:46 2012 +0100
+++ b/src/HOL/Tools/set_comprehension_pointfree.ML Thu Nov 08 11:59:47 2012 +0100
@@ -130,6 +130,17 @@
fun map_atom f (Atom a) = Atom (f a)
| map_atom _ x = x
+fun is_collect_atom (Atom (_, Const(@{const_name Collect}, _) $ _)) = true
+ | is_collect_atom (Atom (_, Const (@{const_name "Groups.uminus_class.uminus"}, _) $ (Const(@{const_name Collect}, _) $ _))) = true
+ | is_collect_atom _ = false
+
+fun mk_split _ [(x, T)] t = (T, Abs (x, T, t))
+ | mk_split rT ((x, T) :: vs) t =
+ let
+ val (T', t') = mk_split rT vs t
+ val t'' = HOLogic.split_const (T, T', rT) $ (Abs (x, T, t'))
+ in (domain_type (fastype_of t''), t'') end
+
fun mk_atom vs (Const (@{const_name "Set.member"}, _) $ x $ s) =
if not (null (loose_bnos s)) then
raise TERM ("mk_atom: bound variables in the set expression", [s])
@@ -140,37 +151,112 @@
let
val bs = loose_bnos x
val vs' = map (nth (rev vs)) bs
- val x' = subst_atomic (map_index (fn (i, j) => (Bound j, Bound i)) (rev bs)) x
+ val subst = map_index (fn (i, j) => (j, Bound i)) (rev bs)
+ |> sort (fn (p1, p2) => int_ord (fst p1, fst p2))
+ |> (fn subst' => map (fn i => the_default (Bound i) (AList.lookup (op =) subst' i)) (0 upto (fst (snd (split_last subst')))))
+ val x' = subst_bounds (subst, x)
val tuple = Pattern bs
val rT = HOLogic.dest_setT (fastype_of s)
- fun mk_split [(x, T)] t = (T, Abs (x, T, t))
- | mk_split ((x, T) :: vs) t =
- let
- val (T', t') = mk_split vs t
- val t'' = HOLogic.split_const (T, T', rT) $ (Abs (x, T, t'))
- in (domain_type (fastype_of t''), t'') end
- val (_, f) = mk_split vs' x'
+ val (_, f) = mk_split rT vs' x'
in (tuple, Atom (tuple, mk_vimage f s)) end)
| mk_atom vs (Const (@{const_name "HOL.Not"}, _) $ t) =
apsnd (map_atom (apsnd mk_Compl)) (mk_atom vs t)
+ | mk_atom vs t =
+ let
+ val bs = loose_bnos t
+ val vs' = map (nth (rev vs)) bs
+ val subst = map_index (fn (i, j) => (j, Bound i)) (rev bs)
+ |> sort (fn (p1, p2) => int_ord (fst p1, fst p2))
+ |> (fn subst' => map (fn i => the_default (Bound i) (AList.lookup (op =) subst' i)) (0 upto (fst (snd (split_last subst')))))
+ val t' = subst_bounds (subst, t)
+ val tuple = Pattern bs
+ val setT = HOLogic.mk_tupleT (map snd vs')
+ val (_, s) = mk_split @{typ bool} vs' t'
+ in
+ (tuple, Atom (tuple, HOLogic.Collect_const setT $ s))
+ end
-fun can_merge (pats1, pats2) =
+fun merge' [] (pats1, pats2) = ([], (pats1, pats2))
+ | merge' pat (pats, []) = (pat, (pats, []))
+ | merge' pat (pats1, pats) =
let
- fun check (Pattern pat1) (Pattern pat2) = (pat1 = pat2)
- orelse (null (inter (op =) pat1 pat2))
+ fun disjoint_to_pat p = null (inter (op =) pat p)
+ val overlap_pats = filter_out disjoint_to_pat pats
+ val rem_pats = filter disjoint_to_pat pats
+ val (pat, (pats', pats1')) = merge' (distinct (op =) (flat overlap_pats @ pat)) (rem_pats, pats1)
in
- forall (fn pat1 => forall (fn pat2 => check pat1 pat2) pats2) pats1
+ (pat, (pats1', pats'))
end
-fun merge_patterns (pats1, pats2) =
- if can_merge (pats1, pats2) then
- union (op =) pats1 pats2
- else raise Fail "merge_patterns: variable groups overlap"
+fun merge ([], pats) = pats
+ | merge (pat :: pats', pats) =
+ let val (pat', (pats1', pats2')) = merge' pat (pats', pats)
+ in pat' :: merge (pats1', pats2') end;
+
+fun restricted_merge ([], pats) = pats
+ | restricted_merge (pat :: pats', pats) =
+ let
+ fun disjoint_to_pat p = null (inter (op =) pat p)
+ val overlap_pats = filter_out disjoint_to_pat pats
+ val rem_pats = filter disjoint_to_pat pats
+ in
+ case overlap_pats of
+ [] => pat :: restricted_merge (pats', rem_pats)
+ | [pat'] => if subset (op =) (pat, pat') then
+ pat' :: restricted_merge (pats', rem_pats)
+ else if subset (op =) (pat', pat) then
+ pat :: restricted_merge (pats', rem_pats)
+ else error "restricted merge: two patterns require relational join"
+ | _ => error "restricted merge: multiple patterns overlap"
+ end;
+
+fun map_atoms f (Atom a) = Atom (f a)
+ | map_atoms f (Un (fm1, fm2)) = Un (pairself (map_atoms f) (fm1, fm2))
+ | map_atoms f (Int (fm1, fm2)) = Int (pairself (map_atoms f) (fm1, fm2))
+
+fun extend Ts bs t = fold (fn b => fn t => mk_sigma (t, HOLogic.mk_UNIV (nth Ts b))) bs t
-fun merge oper (pats1, sp1) (pats2, sp2) = (merge_patterns (pats1, pats2), oper (sp1, sp2))
+fun rearrange vs (pat, pat') t =
+ let
+ val subst = map_index (fn (i, b) => (b, i)) (rev pat)
+ val vs' = map (nth (rev vs)) pat
+ val Ts' = map snd (rev vs')
+ val bs = map (fn b => the (AList.lookup (op =) subst b)) pat'
+ val rt = term_of_pattern Ts' (Pattern bs)
+ val rT = type_of_pattern Ts' (Pattern bs)
+ val (_, f) = mk_split rT vs' rt
+ in
+ mk_image f t
+ end;
+
+fun adjust vs pats (Pattern pat, t) =
+ let
+ val SOME p = find_first (fn p => not (null (inter (op =) pat p))) pats
+ val missing = subtract (op =) pat p
+ val Ts = rev (map snd vs)
+ val t' = extend Ts missing t
+ in (Pattern p, rearrange vs (pat @ missing, p) t') end
-fun mk_formula vs (@{const HOL.conj} $ t1 $ t2) = merge Int (mk_formula vs t1) (mk_formula vs t2)
- | mk_formula vs (@{const HOL.disj} $ t1 $ t2) = merge Un (mk_formula vs t1) (mk_formula vs t2)
+fun adjust_atoms vs pats fm = map_atoms (adjust vs pats) fm
+
+fun merge_inter vs (pats1, fm1) (pats2, fm2) =
+ let
+ val pats = restricted_merge (map dest_Pattern pats1, map dest_Pattern pats2)
+ val (fm1', fm2') = pairself (adjust_atoms vs pats) (fm1, fm2)
+ in
+ (map Pattern pats, Int (fm1', fm2'))
+ end;
+
+fun merge_union vs (pats1, fm1) (pats2, fm2) =
+ let
+ val pats = merge (map dest_Pattern pats1, map dest_Pattern pats2)
+ val (fm1', fm2') = pairself (adjust_atoms vs pats) (fm1, fm2)
+ in
+ (map Pattern pats, Un (fm1', fm2'))
+ end;
+
+fun mk_formula vs (@{const HOL.conj} $ t1 $ t2) = merge_inter vs (mk_formula vs t1) (mk_formula vs t2)
+ | mk_formula vs (@{const HOL.disj} $ t1 $ t2) = merge_union vs (mk_formula vs t1) (mk_formula vs t2)
| mk_formula vs t = apfst single (mk_atom vs t)
fun strip_Int (Int (fm1, fm2)) = fm1 :: (strip_Int fm2)
@@ -187,6 +273,10 @@
subst_bounds (map Bound bperm, t)
end;
+fun is_reordering t =
+ let val (t', _, _) = HOLogic.strip_psplits t
+ in forall (fn Bound _ => true) (HOLogic.strip_tuple t') end
+
fun mk_pointfree_expr t =
let
val ((x, T), (vs, t'')) = apsnd strip_ex (dest_Collect t)
@@ -210,7 +300,9 @@
val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
val t = mk_split_abs (rev ((x, T) :: vs)) pat (reorder_bounds pats f)
in
- (fm, mk_image t (mk_set fm))
+ if the_default false (try is_reordering t) andalso is_collect_atom fm then
+ error "mk_pointfree_expr: trivial case"
+ else (fm, mk_image t (mk_set fm))
end;
val rewrite_term = try mk_pointfree_expr
@@ -236,6 +328,9 @@
val vimageE' =
@{lemma "a \<notin> f -` B ==> (\<And> x. f a = x ==> x \<notin> B ==> P) ==> P" by simp}
+val collectI' = @{lemma "\<not> P a ==> a \<notin> {x. P x}" by auto}
+val collectE' = @{lemma "a \<notin> {x. P x} ==> (\<not> P a ==> Q) ==> Q" by auto}
+
val elim_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]}
THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE}))
THEN' REPEAT_DETERM o etac @{thm conjE}
@@ -265,8 +360,13 @@
THEN' tac1_of_formula fm2
| tac1_of_formula (Atom _) =
REPEAT_DETERM1 o (rtac @{thm SigmaI}
+ ORELSE' ((rtac @{thm CollectI} ORELSE' rtac collectI') THEN' TRY o Simplifier.simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}]))
ORELSE' ((rtac @{thm vimageI2} ORELSE' rtac vimageI2') THEN'
TRY o Simplifier.simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}]))
+ ORELSE' (rtac @{thm image_eqI} THEN'
+ (REPEAT_DETERM o
+ (rtac @{thm refl}
+ ORELSE' rtac @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]})))
ORELSE' rtac @{thm UNIV_I}
ORELSE' rtac @{thm iffD2[OF Compl_iff]}
ORELSE' atac)
@@ -282,13 +382,21 @@
THEN' rtac @{thm disjI2}
THEN' tac2_of_formula fm2
| tac2_of_formula (Atom _) =
- TRY o REPEAT_DETERM1 o
+ REPEAT_DETERM o
(dtac @{thm iffD1[OF mem_Sigma_iff]}
+ ORELSE' ((etac @{thm CollectE} ORELSE' etac collectE') THEN'
+ TRY o Simplifier.full_simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}]) THEN'
+ REPEAT_DETERM o etac @{thm Pair_inject} THEN' TRY o hyp_subst_tac THEN' TRY o rtac @{thm refl})
+ ORELSE' (REPEAT_DETERM1 o etac @{thm imageE}
+ THEN' (TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm})
+ THEN' REPEAT_DETERM o etac @{thm Pair_inject}
+ THEN' TRY o hyp_subst_tac)
ORELSE' etac @{thm conjE}
ORELSE' etac @{thm ComplE}
ORELSE' ((etac @{thm vimageE} ORELSE' etac vimageE')
THEN' TRY o Simplifier.full_simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}])
THEN' TRY o hyp_subst_tac)
+ ORELSE' (REPEAT_DETERM1 o etac @{thm Pair_inject} THEN' TRY o hyp_subst_tac)
ORELSE' atac)
fun tac ctxt fm =
@@ -302,7 +410,7 @@
THEN' rtac @{thm iffD2[OF mem_Collect_eq]}
THEN' REPEAT_DETERM o resolve_tac @{thms exI}
THEN' (TRY o REPEAT_ALL_NEW (rtac @{thm conjI}))
- THEN' (K (TRY (SOMEGOAL ((TRY o hyp_subst_tac) THEN' rtac @{thm refl}))))
+ THEN' (K (TRY (FIRSTGOAL ((TRY o hyp_subst_tac) THEN' rtac @{thm refl}))))
THEN' (fn i => EVERY (rev (map_index (fn (j, f) =>
REPEAT_DETERM (etac @{thm IntE} (i + j)) THEN tac2_of_formula f (i + j)) (strip_Int fm))))
in