src/HOL/Tools/set_comprehension_pointfree.ML
changeset 50025 19965e6a705e
parent 50024 b7265db3a1dc
child 50026 d9871e5ea0e1
--- a/src/HOL/Tools/set_comprehension_pointfree.ML	Thu Nov 08 11:59:46 2012 +0100
+++ b/src/HOL/Tools/set_comprehension_pointfree.ML	Thu Nov 08 11:59:47 2012 +0100
@@ -130,6 +130,17 @@
 fun map_atom f (Atom a) = Atom (f a)
   | map_atom _ x = x
 
+fun is_collect_atom (Atom (_, Const(@{const_name Collect}, _) $ _)) = true
+  | is_collect_atom (Atom (_, Const (@{const_name "Groups.uminus_class.uminus"}, _) $ (Const(@{const_name Collect}, _) $ _))) = true
+  | is_collect_atom _ = false
+
+fun mk_split _ [(x, T)] t = (T, Abs (x, T, t))
+  | mk_split rT ((x, T) :: vs) t =
+    let
+      val (T', t') = mk_split rT vs t
+      val t'' = HOLogic.split_const (T, T', rT) $ (Abs (x, T, t'))
+    in (domain_type (fastype_of t''), t'') end
+
 fun mk_atom vs (Const (@{const_name "Set.member"}, _) $ x $ s) =
     if not (null (loose_bnos s)) then
       raise TERM ("mk_atom: bound variables in the set expression", [s])
@@ -140,37 +151,112 @@
         let
           val bs = loose_bnos x
           val vs' = map (nth (rev vs)) bs
-          val x' = subst_atomic (map_index (fn (i, j) => (Bound j, Bound i)) (rev bs)) x
+          val subst = map_index (fn (i, j) => (j, Bound i)) (rev bs)
+            |> sort (fn (p1, p2) => int_ord (fst p1, fst p2))
+            |> (fn subst' => map (fn i => the_default (Bound i) (AList.lookup (op =) subst' i)) (0 upto (fst (snd (split_last subst')))))  
+          val x' = subst_bounds (subst, x)
           val tuple = Pattern bs
           val rT = HOLogic.dest_setT (fastype_of s)
-          fun mk_split [(x, T)] t = (T, Abs (x, T, t))
-            | mk_split ((x, T) :: vs) t =
-                let
-                  val (T', t') = mk_split vs t
-                  val t'' = HOLogic.split_const (T, T', rT) $ (Abs (x, T, t'))
-                in (domain_type (fastype_of t''), t'') end
-          val (_, f) = mk_split vs' x'
+          val (_, f) = mk_split rT vs' x'
         in (tuple, Atom (tuple, mk_vimage f s)) end)
   | mk_atom vs (Const (@{const_name "HOL.Not"}, _) $ t) =
       apsnd (map_atom (apsnd mk_Compl)) (mk_atom vs t)
+  | mk_atom vs t =
+    let
+      val bs = loose_bnos t
+      val vs' = map (nth (rev vs)) bs
+      val subst = map_index (fn (i, j) => (j, Bound i)) (rev bs)
+        |> sort (fn (p1, p2) => int_ord (fst p1, fst p2))
+        |> (fn subst' => map (fn i => the_default (Bound i) (AList.lookup (op =) subst' i)) (0 upto (fst (snd (split_last subst')))))
+      val t' = subst_bounds (subst, t)
+      val tuple = Pattern bs
+      val setT = HOLogic.mk_tupleT (map snd vs')
+      val (_, s) = mk_split @{typ bool} vs' t'
+    in
+      (tuple, Atom (tuple, HOLogic.Collect_const setT $ s))
+    end
 
-fun can_merge (pats1, pats2) =
+fun merge' [] (pats1, pats2) = ([], (pats1, pats2))
+  | merge' pat (pats, []) = (pat, (pats, []))
+  | merge' pat (pats1, pats) =
   let
-    fun check (Pattern pat1) (Pattern pat2) = (pat1 = pat2)
-      orelse (null (inter (op =) pat1 pat2))
+    fun disjoint_to_pat p = null (inter (op =) pat p)
+    val overlap_pats = filter_out disjoint_to_pat pats
+    val rem_pats = filter disjoint_to_pat pats
+    val (pat, (pats', pats1')) = merge' (distinct (op =) (flat overlap_pats @ pat)) (rem_pats, pats1)
   in
-    forall (fn pat1 => forall (fn pat2 => check pat1 pat2) pats2) pats1 
+    (pat, (pats1', pats'))
   end
 
-fun merge_patterns (pats1, pats2) =
-  if can_merge (pats1, pats2) then
-    union (op =) pats1 pats2
-  else raise Fail "merge_patterns: variable groups overlap"
+fun merge ([], pats) = pats
+  | merge (pat :: pats', pats) =
+  let val (pat', (pats1', pats2')) = merge' pat (pats', pats)
+  in pat' :: merge (pats1', pats2') end;
+
+fun restricted_merge ([], pats) = pats
+  | restricted_merge (pat :: pats', pats) =
+  let
+    fun disjoint_to_pat p = null (inter (op =) pat p)
+    val overlap_pats = filter_out disjoint_to_pat pats
+    val rem_pats = filter disjoint_to_pat pats
+  in
+    case overlap_pats of
+      [] => pat :: restricted_merge (pats', rem_pats)
+    | [pat'] => if subset (op =) (pat, pat') then
+        pat' :: restricted_merge (pats', rem_pats)
+      else if subset (op =) (pat', pat) then
+        pat :: restricted_merge (pats', rem_pats)
+      else error "restricted merge: two patterns require relational join"
+    | _ => error "restricted merge: multiple patterns overlap"
+  end;
+  
+fun map_atoms f (Atom a) = Atom (f a)
+  | map_atoms f (Un (fm1, fm2)) = Un (pairself (map_atoms f) (fm1, fm2))
+  | map_atoms f (Int (fm1, fm2)) = Int (pairself (map_atoms f) (fm1, fm2))
+
+fun extend Ts bs t = fold (fn b => fn t => mk_sigma (t, HOLogic.mk_UNIV (nth Ts b))) bs t
 
-fun merge oper (pats1, sp1) (pats2, sp2) = (merge_patterns (pats1, pats2), oper (sp1, sp2))
+fun rearrange vs (pat, pat') t =
+  let
+    val subst = map_index (fn (i, b) => (b, i)) (rev pat)
+    val vs' = map (nth (rev vs)) pat
+    val Ts' = map snd (rev vs')
+    val bs = map (fn b => the (AList.lookup (op =) subst b)) pat'
+    val rt = term_of_pattern Ts' (Pattern bs)
+    val rT = type_of_pattern Ts' (Pattern bs)
+    val (_, f) = mk_split rT vs' rt
+  in
+    mk_image f t
+  end;
+
+fun adjust vs pats (Pattern pat, t) =
+  let
+    val SOME p = find_first (fn p => not (null (inter (op =) pat p))) pats
+    val missing = subtract (op =) pat p
+    val Ts = rev (map snd vs)
+    val t' = extend Ts missing t
+  in (Pattern p, rearrange vs (pat @ missing, p) t') end
 
-fun mk_formula vs (@{const HOL.conj} $ t1 $ t2) = merge Int (mk_formula vs t1) (mk_formula vs t2)
-  | mk_formula vs (@{const HOL.disj} $ t1 $ t2) = merge Un (mk_formula vs t1) (mk_formula vs t2)
+fun adjust_atoms vs pats fm = map_atoms (adjust vs pats) fm
+
+fun merge_inter vs (pats1, fm1) (pats2, fm2) =
+  let
+    val pats = restricted_merge (map dest_Pattern pats1, map dest_Pattern pats2) 
+    val (fm1', fm2') = pairself (adjust_atoms vs pats) (fm1, fm2)
+  in
+    (map Pattern pats, Int (fm1', fm2'))
+  end;
+
+fun merge_union vs (pats1, fm1) (pats2, fm2) = 
+  let
+    val pats = merge (map dest_Pattern pats1, map dest_Pattern pats2)
+    val (fm1', fm2') = pairself (adjust_atoms vs pats) (fm1, fm2)
+  in
+    (map Pattern pats, Un (fm1', fm2'))
+  end;
+
+fun mk_formula vs (@{const HOL.conj} $ t1 $ t2) = merge_inter vs (mk_formula vs t1) (mk_formula vs t2)
+  | mk_formula vs (@{const HOL.disj} $ t1 $ t2) = merge_union vs (mk_formula vs t1) (mk_formula vs t2)
   | mk_formula vs t = apfst single (mk_atom vs t)
 
 fun strip_Int (Int (fm1, fm2)) = fm1 :: (strip_Int fm2) 
@@ -187,6 +273,10 @@
     subst_bounds (map Bound bperm, t)
   end;
 
+fun is_reordering t =
+  let val (t', _, _) = HOLogic.strip_psplits t
+  in forall (fn Bound _ => true) (HOLogic.strip_tuple t') end
+
 fun mk_pointfree_expr t =
   let
     val ((x, T), (vs, t'')) = apsnd strip_ex (dest_Collect t)
@@ -210,7 +300,9 @@
     val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
     val t = mk_split_abs (rev ((x, T) :: vs)) pat (reorder_bounds pats f)
   in
-    (fm, mk_image t (mk_set fm))
+    if the_default false (try is_reordering t) andalso is_collect_atom fm then
+      error "mk_pointfree_expr: trivial case" 
+    else (fm, mk_image t (mk_set fm))
   end;
 
 val rewrite_term = try mk_pointfree_expr
@@ -236,6 +328,9 @@
 val vimageE' =
   @{lemma "a \<notin> f -` B ==> (\<And> x. f a = x ==> x \<notin> B ==> P) ==> P" by simp}
 
+val collectI' = @{lemma "\<not> P a ==> a \<notin> {x. P x}" by auto}
+val collectE' = @{lemma "a \<notin> {x. P x} ==> (\<not> P a ==> Q) ==> Q" by auto}
+
 val elim_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]}
   THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE}))
   THEN' REPEAT_DETERM o etac @{thm conjE}
@@ -265,8 +360,13 @@
     THEN' tac1_of_formula fm2
   | tac1_of_formula (Atom _) =
     REPEAT_DETERM1 o (rtac @{thm SigmaI}
+      ORELSE' ((rtac @{thm CollectI} ORELSE' rtac collectI') THEN' TRY o Simplifier.simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}]))
       ORELSE' ((rtac @{thm vimageI2} ORELSE' rtac vimageI2') THEN'
         TRY o Simplifier.simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}])) 
+      ORELSE' (rtac @{thm image_eqI} THEN'
+    (REPEAT_DETERM o
+      (rtac @{thm refl}
+      ORELSE' rtac @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]})))
       ORELSE' rtac @{thm UNIV_I}
       ORELSE' rtac @{thm iffD2[OF Compl_iff]}
       ORELSE' atac)
@@ -282,13 +382,21 @@
     THEN' rtac @{thm disjI2}
     THEN' tac2_of_formula fm2
   | tac2_of_formula (Atom _) =
-    TRY o REPEAT_DETERM1 o
+    REPEAT_DETERM o
       (dtac @{thm iffD1[OF mem_Sigma_iff]}
+       ORELSE' ((etac @{thm CollectE} ORELSE' etac collectE') THEN'
+         TRY o Simplifier.full_simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}]) THEN'
+         REPEAT_DETERM o etac @{thm Pair_inject} THEN' TRY o hyp_subst_tac THEN' TRY o rtac @{thm refl})
+       ORELSE' (REPEAT_DETERM1 o etac @{thm imageE}
+         THEN' (TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm})
+         THEN' REPEAT_DETERM o etac @{thm Pair_inject}
+         THEN' TRY o hyp_subst_tac)
        ORELSE' etac @{thm conjE}
        ORELSE' etac @{thm ComplE}
        ORELSE' ((etac @{thm vimageE} ORELSE' etac vimageE')
         THEN' TRY o Simplifier.full_simp_tac (HOL_basic_ss addsimps [@{thm prod.cases}])
         THEN' TRY o hyp_subst_tac)
+       ORELSE' (REPEAT_DETERM1 o etac @{thm Pair_inject} THEN' TRY o hyp_subst_tac)
        ORELSE' atac)
 
 fun tac ctxt fm =
@@ -302,7 +410,7 @@
       THEN' rtac @{thm iffD2[OF mem_Collect_eq]}
       THEN' REPEAT_DETERM o resolve_tac @{thms exI}
       THEN' (TRY o REPEAT_ALL_NEW (rtac @{thm conjI}))
-      THEN' (K (TRY (SOMEGOAL ((TRY o hyp_subst_tac) THEN' rtac @{thm refl}))))
+      THEN' (K (TRY (FIRSTGOAL ((TRY o hyp_subst_tac) THEN' rtac @{thm refl}))))
       THEN' (fn i => EVERY (rev (map_index (fn (j, f) =>
         REPEAT_DETERM (etac @{thm IntE} (i + j)) THEN tac2_of_formula f (i + j)) (strip_Int fm))))
   in