--- a/src/HOL/Tools/set_comprehension_pointfree.ML Thu Nov 08 10:02:38 2012 +0100
+++ b/src/HOL/Tools/set_comprehension_pointfree.ML Thu Nov 08 11:59:46 2012 +0100
@@ -103,28 +103,25 @@
(* patterns *)
-datatype pattern = TBound of int | TPair of pattern * pattern;
+datatype pattern = Pattern of int list
-fun mk_pattern (Bound n) = TBound n
- | mk_pattern (Const (@{const_name "Product_Type.Pair"}, _) $ l $ r) =
- TPair (mk_pattern l, mk_pattern r)
- | mk_pattern t = raise TERM ("mk_pattern: only bound variable tuples currently supported", [t]);
+fun dest_Pattern (Pattern bs) = bs
-fun type_of_pattern Ts (TBound n) = nth Ts n
- | type_of_pattern Ts (TPair (l, r)) = HOLogic.mk_prodT (type_of_pattern Ts l, type_of_pattern Ts r)
+fun dest_bound (Bound i) = i
+ | dest_bound t = raise TERM("dest_bound", [t]);
-fun term_of_pattern _ (TBound n) = Bound n
- | term_of_pattern Ts (TPair (l, r)) =
+fun mk_pattern t = case try ((map dest_bound) o HOLogic.strip_tuple) t of
+ SOME p => Pattern p
+ | NONE => raise TERM ("mk_pattern: only tuples of bound variables supported", [t]);
+
+fun type_of_pattern Ts (Pattern bs) = HOLogic.mk_tupleT (map (nth Ts) bs)
+
+fun term_of_pattern Ts (Pattern bs) =
let
- val (lt, rt) = pairself (term_of_pattern Ts) (l, r)
- val (lT, rT) = pairself (curry fastype_of1 Ts) (lt, rt)
- in
- HOLogic.pair_const lT rT $ lt $ rt
- end;
-
-fun bounds_of_pattern (TBound i) = [i]
- | bounds_of_pattern (TPair (l, r)) = union (op =) (bounds_of_pattern l) (bounds_of_pattern r)
-
+ fun mk [b] = Bound b
+ | mk (b :: bs) = HOLogic.pair_const (nth Ts b) (type_of_pattern Ts (Pattern bs))
+ $ Bound b $ mk bs
+ in mk bs end;
(* formulas *)
@@ -144,7 +141,7 @@
val bs = loose_bnos x
val vs' = map (nth (rev vs)) bs
val x' = subst_atomic (map_index (fn (i, j) => (Bound j, Bound i)) (rev bs)) x
- val tuple = foldr1 TPair (map TBound bs)
+ val tuple = Pattern bs
val rT = HOLogic.dest_setT (fastype_of s)
fun mk_split [(x, T)] t = (T, Abs (x, T, t))
| mk_split ((x, T) :: vs) t =
@@ -159,8 +156,8 @@
fun can_merge (pats1, pats2) =
let
- fun check pat1 pat2 = (pat1 = pat2)
- orelse (inter (op =) (bounds_of_pattern pat1) (bounds_of_pattern pat2) = [])
+ fun check (Pattern pat1) (Pattern pat2) = (pat1 = pat2)
+ orelse (null (inter (op =) pat1 pat2))
in
forall (fn pat1 => forall (fn pat2 => check pat1 pat2) pats2) pats1
end
@@ -183,7 +180,7 @@
fun reorder_bounds pats t =
let
- val bounds = maps bounds_of_pattern pats
+ val bounds = maps dest_Pattern pats
val bperm = bounds ~~ ((length bounds - 1) downto 0)
|> sort (fn (i,j) => int_ord (fst i, fst j)) |> map snd
in