# HG changeset patch # User bulwahn # Date 1352372387 -3600 # Node ID 19965e6a705e26bb8462809e5a1b45aab54cb543 # Parent b7265db3a1dc7b116de7505726193903ba3283ab handling arbitrary terms in the set comprehension and more general merging of patterns possible in the set_comprehension_pointfree simproc diff -r b7265db3a1dc -r 19965e6a705e src/HOL/Tools/set_comprehension_pointfree.ML --- a/src/HOL/Tools/set_comprehension_pointfree.ML Thu Nov 08 11:59:46 2012 +0100 +++ b/src/HOL/Tools/set_comprehension_pointfree.ML Thu Nov 08 11:59:47 2012 +0100 @@ -130,6 +130,17 @@ fun map_atom f (Atom a) = Atom (f a) | map_atom _ x = x +fun is_collect_atom (Atom (_, Const(@{const_name Collect}, _) $ _)) = true + | is_collect_atom (Atom (_, Const (@{const_name "Groups.uminus_class.uminus"}, _) $ (Const(@{const_name Collect}, _) $ _))) = true + | is_collect_atom _ = false + +fun mk_split _ [(x, T)] t = (T, Abs (x, T, t)) + | mk_split rT ((x, T) :: vs) t = + let + val (T', t') = mk_split rT vs t + val t'' = HOLogic.split_const (T, T', rT) $ (Abs (x, T, t')) + in (domain_type (fastype_of t''), t'') end + fun mk_atom vs (Const (@{const_name "Set.member"}, _) $ x $ s) = if not (null (loose_bnos s)) then raise TERM ("mk_atom: bound variables in the set expression", [s]) @@ -140,37 +151,112 @@ let val bs = loose_bnos x val vs' = map (nth (rev vs)) bs - val x' = subst_atomic (map_index (fn (i, j) => (Bound j, Bound i)) (rev bs)) x + val subst = map_index (fn (i, j) => (j, Bound i)) (rev bs) + |> sort (fn (p1, p2) => int_ord (fst p1, fst p2)) + |> (fn subst' => map (fn i => the_default (Bound i) (AList.lookup (op =) subst' i)) (0 upto (fst (snd (split_last subst'))))) + val x' = subst_bounds (subst, x) val tuple = Pattern bs val rT = HOLogic.dest_setT (fastype_of s) - fun mk_split [(x, T)] t = (T, Abs (x, T, t)) - | mk_split ((x, T) :: vs) t = - let - val (T', t') = mk_split vs t - val t'' = HOLogic.split_const (T, T', rT) $ (Abs (x, T, t')) - in (domain_type (fastype_of t''), t'') end - val (_, f) = mk_split vs' x' + val (_, f) = mk_split rT vs' x' in (tuple, Atom (tuple, mk_vimage f s)) end) | mk_atom vs (Const (@{const_name "HOL.Not"}, _) $ t) = apsnd (map_atom (apsnd mk_Compl)) (mk_atom vs t) + | mk_atom vs t = + let + val bs = loose_bnos t + val vs' = map (nth (rev vs)) bs + val subst = map_index (fn (i, j) => (j, Bound i)) (rev bs) + |> sort (fn (p1, p2) => int_ord (fst p1, fst p2)) + |> (fn subst' => map (fn i => the_default (Bound i) (AList.lookup (op =) subst' i)) (0 upto (fst (snd (split_last subst'))))) + val t' = subst_bounds (subst, t) + val tuple = Pattern bs + val setT = HOLogic.mk_tupleT (map snd vs') + val (_, s) = mk_split @{typ bool} vs' t' + in + (tuple, Atom (tuple, HOLogic.Collect_const setT $ s)) + end -fun can_merge (pats1, pats2) = +fun merge' [] (pats1, pats2) = ([], (pats1, pats2)) + | merge' pat (pats, []) = (pat, (pats, [])) + | merge' pat (pats1, pats) = let - fun check (Pattern pat1) (Pattern pat2) = (pat1 = pat2) - orelse (null (inter (op =) pat1 pat2)) + fun disjoint_to_pat p = null (inter (op =) pat p) + val overlap_pats = filter_out disjoint_to_pat pats + val rem_pats = filter disjoint_to_pat pats + val (pat, (pats', pats1')) = merge' (distinct (op =) (flat overlap_pats @ pat)) (rem_pats, pats1) in - forall (fn pat1 => forall (fn pat2 => check pat1 pat2) pats2) pats1 + (pat, (pats1', pats')) end -fun merge_patterns (pats1, pats2) = - if can_merge (pats1, pats2) then - union (op =) pats1 pats2 - else raise Fail "merge_patterns: variable groups overlap" +fun merge ([], pats) = pats + | merge (pat :: pats', pats) = + let val (pat', (pats1', pats2')) = merge' pat (pats', pats) + in pat' :: merge (pats1', pats2') end; + +fun restricted_merge ([], pats) = pats + | restricted_merge (pat :: pats', pats) = + let + fun disjoint_to_pat p = null (inter (op =) pat p) + val overlap_pats = filter_out disjoint_to_pat pats + val rem_pats = filter disjoint_to_pat pats + in + case overlap_pats of + [] => pat :: restricted_merge (pats', rem_pats) + | [pat'] => if subset (op =) (pat, pat') then + pat' :: restricted_merge (pats', rem_pats) + else if subset (op =) (pat', pat) then + pat :: restricted_merge (pats', rem_pats) + else error "restricted merge: two patterns require relational join" + | _ => error "restricted merge: multiple patterns overlap" + end; + +fun map_atoms f (Atom a) = Atom (f a) + | map_atoms f (Un (fm1, fm2)) = Un (pairself (map_atoms f) (fm1, fm2)) + | map_atoms f (Int (fm1, fm2)) = Int (pairself (map_atoms f) (fm1, fm2)) + +fun extend Ts bs t = fold (fn b => fn t => mk_sigma (t, HOLogic.mk_UNIV (nth Ts b))) bs t -fun merge oper (pats1, sp1) (pats2, sp2) = (merge_patterns (pats1, pats2), oper (sp1, sp2)) +fun rearrange vs (pat, pat') t = + let + val subst = map_index (fn (i, b) => (b, i)) (rev pat) + val vs' = map (nth (rev vs)) pat + val Ts' = map snd (rev vs') + val bs = map (fn b => the (AList.lookup (op =) subst b)) pat' + val rt = term_of_pattern Ts' (Pattern bs) + val rT = type_of_pattern Ts' (Pattern bs) + val (_, f) = mk_split rT vs' rt + in + mk_image f t + end; + +fun adjust vs pats (Pattern pat, t) = + let + val SOME p = find_first (fn p => not (null (inter (op =) pat p))) pats + val missing = subtract (op =) pat p + val Ts = rev (map snd vs) + val t' = extend Ts missing t + in (Pattern p, rearrange vs (pat @ missing, p) t') end -fun mk_formula vs (@{const HOL.conj} $ t1 $ t2) = merge Int (mk_formula vs t1) (mk_formula vs t2) - | mk_formula vs (@{const HOL.disj} $ t1 $ t2) = merge Un (mk_formula vs t1) (mk_formula vs t2) +fun adjust_atoms vs pats fm = map_atoms (adjust vs pats) fm + +fun merge_inter vs (pats1, fm1) (pats2, fm2) = + let + val pats = restricted_merge (map dest_Pattern pats1, map dest_Pattern pats2) + val (fm1', fm2') = pairself (adjust_atoms vs pats) (fm1, fm2) + in + (map Pattern pats, Int (fm1', fm2')) + end; + +fun merge_union vs (pats1, fm1) (pats2, fm2) = + let + val pats = merge (map dest_Pattern pats1, map dest_Pattern pats2) + val (fm1', fm2') = pairself (adjust_atoms vs pats) (fm1, fm2) + in + (map Pattern pats, Un (fm1', fm2')) + end; + +fun mk_formula vs (@{const HOL.conj} $ t1 $ t2) = merge_inter vs (mk_formula vs t1) (mk_formula vs t2) + | mk_formula vs (@{const HOL.disj} $ t1 $ t2) = merge_union vs (mk_formula vs t1) (mk_formula vs t2) | mk_formula vs t = apfst single (mk_atom vs t) fun strip_Int (Int (fm1, fm2)) = fm1 :: (strip_Int fm2) @@ -187,6 +273,10 @@ subst_bounds (map Bound bperm, t) end; +fun is_reordering t = + let val (t', _, _) = HOLogic.strip_psplits t + in forall (fn Bound _ => true) (HOLogic.strip_tuple t') end + fun mk_pointfree_expr t = let val ((x, T), (vs, t'')) = apsnd strip_ex (dest_Collect t) @@ -210,7 +300,9 @@ val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats) val t = mk_split_abs (rev ((x, T) :: vs)) pat (reorder_bounds pats f) in - (fm, mk_image t (mk_set fm)) + if the_default false (try is_reordering t) andalso is_collect_atom fm then + error "mk_pointfree_expr: trivial case" + else (fm, mk_image t (mk_set fm)) end; val rewrite_term = try mk_pointfree_expr @@ -236,6 +328,9 @@ val vimageE' = @{lemma "a \ f -` B ==> (\ x. f a = x ==> x \ B ==> P) ==> P" by simp} +val collectI' = @{lemma "\ P a ==> a \ {x. P x}" by auto} +val collectE' = @{lemma "a \ {x. P x} ==> (\ P a ==> Q) ==> Q" by auto} + val elim_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]} THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE})) THEN' REPEAT_DETERM o etac @{thm conjE} @@ -265,8 +360,13 @@ THEN' tac1_of_formula fm2 | tac1_of_formula (Atom _) = REPEAT_DETERM1 o (rtac @{thm SigmaI} + ORELSE' ((rtac @{thm CollectI} ORELSE' rtac collectI') THEN' TRY o Simplifier.simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}])) ORELSE' ((rtac @{thm vimageI2} ORELSE' rtac vimageI2') THEN' TRY o Simplifier.simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}])) + ORELSE' (rtac @{thm image_eqI} THEN' + (REPEAT_DETERM o + (rtac @{thm refl} + ORELSE' rtac @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]}))) ORELSE' rtac @{thm UNIV_I} ORELSE' rtac @{thm iffD2[OF Compl_iff]} ORELSE' atac) @@ -282,13 +382,21 @@ THEN' rtac @{thm disjI2} THEN' tac2_of_formula fm2 | tac2_of_formula (Atom _) = - TRY o REPEAT_DETERM1 o + REPEAT_DETERM o (dtac @{thm iffD1[OF mem_Sigma_iff]} + ORELSE' ((etac @{thm CollectE} ORELSE' etac collectE') THEN' + TRY o Simplifier.full_simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}]) THEN' + REPEAT_DETERM o etac @{thm Pair_inject} THEN' TRY o hyp_subst_tac THEN' TRY o rtac @{thm refl}) + ORELSE' (REPEAT_DETERM1 o etac @{thm imageE} + THEN' (TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm}) + THEN' REPEAT_DETERM o etac @{thm Pair_inject} + THEN' TRY o hyp_subst_tac) ORELSE' etac @{thm conjE} ORELSE' etac @{thm ComplE} ORELSE' ((etac @{thm vimageE} ORELSE' etac vimageE') THEN' TRY o Simplifier.full_simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}]) THEN' TRY o hyp_subst_tac) + ORELSE' (REPEAT_DETERM1 o etac @{thm Pair_inject} THEN' TRY o hyp_subst_tac) ORELSE' atac) fun tac ctxt fm = @@ -302,7 +410,7 @@ THEN' rtac @{thm iffD2[OF mem_Collect_eq]} THEN' REPEAT_DETERM o resolve_tac @{thms exI} THEN' (TRY o REPEAT_ALL_NEW (rtac @{thm conjI})) - THEN' (K (TRY (SOMEGOAL ((TRY o hyp_subst_tac) THEN' rtac @{thm refl})))) + THEN' (K (TRY (FIRSTGOAL ((TRY o hyp_subst_tac) THEN' rtac @{thm refl})))) THEN' (fn i => EVERY (rev (map_index (fn (j, f) => REPEAT_DETERM (etac @{thm IntE} (i + j)) THEN tac2_of_formula f (i + j)) (strip_Int fm)))) in