extending the setcomprehension_pointfree simproc to handle nesting disjunctions, conjunctions and negations (with contributions from Rafal Kolanski, NICTA); tuned
authorbulwahn
Sun, 14 Oct 2012 19:16:32 +0200
changeset 49849 d9822ec4f434
parent 49848 f222a054342e
child 49850 873fa7156468
extending the setcomprehension_pointfree simproc to handle nesting disjunctions, conjunctions and negations (with contributions from Rafal Kolanski, NICTA); tuned
src/HOL/Tools/set_comprehension_pointfree.ML
--- a/src/HOL/Tools/set_comprehension_pointfree.ML	Sat Oct 13 21:09:20 2012 +0200
+++ b/src/HOL/Tools/set_comprehension_pointfree.ML	Sun Oct 14 19:16:32 2012 +0200
@@ -7,16 +7,15 @@
 
 signature SET_COMPREHENSION_POINTFREE =
 sig
+  val base_simproc : simpset -> cterm -> thm option
   val code_simproc : simpset -> cterm -> thm option
   val simproc : simpset -> cterm -> thm option
-  val rewrite_term : term -> term option
-  (* FIXME: function conv is not a conversion, i.e. of type cterm -> thm, MAYBE rename *)
-  val conv : Proof.context -> term -> thm option
 end
 
 structure Set_Comprehension_Pointfree : SET_COMPREHENSION_POINTFREE =
 struct
 
+
 (* syntactic operations *)
 
 fun mk_inf (t1, t2) =
@@ -59,9 +58,6 @@
       T1 --> (setT --> T2) --> resT) $ t1 $ absdummy setT t2
   end;
 
-fun dest_Bound (Bound x) = x
-  | dest_Bound t = raise TERM("dest_Bound", [t]);
-
 fun dest_Collect (Const (@{const_name Collect}, _) $ Abs (_, _, t)) = t
   | dest_Collect t = raise TERM ("dest_Collect", [t])
 
@@ -74,92 +70,213 @@
   end
   | strip_ex t = ([], t)
 
-fun list_tupled_abs [] f = f
-  | list_tupled_abs [(n, T)] f = (Abs (n, T, f))
-  | list_tupled_abs ((n, T)::v::vs) f =
-      HOLogic.mk_split (Abs (n, T, list_tupled_abs (v::vs) f))
+fun mk_prod1 Ts (t1, t2) =
+  let
+    val (T1, T2) = pairself (curry fastype_of1 Ts) (t1, t2)
+  in
+    HOLogic.pair_const T1 T2 $ t1 $ t2
+  end;
+
+
+(* patterns *)
+
+datatype pattern = TBound of int | TPair of pattern * pattern;
+
+fun mk_pattern (Bound n) = TBound n
+  | mk_pattern (Const (@{const_name "Product_Type.Pair"}, _) $ l $ r) =
+      TPair (mk_pattern l, mk_pattern r)
+  | mk_pattern t = raise TERM ("mk_pattern: only bound variable tuples currently supported", [t]);
+
+fun type_of_pattern Ts (TBound n) = nth Ts n
+  | type_of_pattern Ts (TPair (l, r)) = HOLogic.mk_prodT (type_of_pattern Ts l, type_of_pattern Ts r)
+
+fun term_of_pattern _ (TBound n) = Bound n
+  | term_of_pattern Ts (TPair (l, r)) =
+    let
+      val (lt, rt) = pairself (term_of_pattern Ts) (l, r)
+      val (lT, rT) = pairself (curry fastype_of1 Ts) (lt, rt) 
+    in
+      HOLogic.pair_const lT rT $ lt $ rt
+    end;
+
+fun bounds_of_pattern (TBound i) = [i]
+  | bounds_of_pattern (TPair (l, r)) = union (op =) (bounds_of_pattern l) (bounds_of_pattern r)
+
+
+(* formulas *)
+
+datatype formula = Atom of (pattern * term) | Int of formula * formula | Un of formula * formula
+
+fun mk_atom (Const (@{const_name "Set.member"}, _) $ x $ s) = (mk_pattern x, Atom (mk_pattern x, s))
+  | mk_atom (Const (@{const_name "HOL.Not"}, _) $ (Const (@{const_name "Set.member"}, _) $ x $ s)) =
+      (mk_pattern x, Atom (mk_pattern x, mk_Compl s))
+
+fun can_merge (pats1, pats2) =
+  let
+    fun check pat1 pat2 = (pat1 = pat2)
+      orelse (inter (op =) (bounds_of_pattern pat1) (bounds_of_pattern pat2) = [])
+  in
+    forall (fn pat1 => forall (fn pat2 => check pat1 pat2) pats2) pats1 
+  end
+
+fun merge_patterns (pats1, pats2) =
+  if can_merge (pats1, pats2) then
+    union (op =) pats1 pats2
+  else raise Fail "merge_patterns: variable groups overlap"
+
+fun merge oper (pats1, sp1) (pats2, sp2) = (merge_patterns (pats1, pats2), oper (sp1, sp2))
+
+fun mk_formula (@{const HOL.conj} $ t1 $ t2) = merge Int (mk_formula t1) (mk_formula t2)
+  | mk_formula (@{const HOL.disj} $ t1 $ t2) = merge Un (mk_formula t1) (mk_formula t2)
+  | mk_formula t = apfst single (mk_atom t)
+
+
+(* term construction *)
+
+fun reorder_bounds pats t =
+  let
+    val bounds = maps bounds_of_pattern pats
+    val bperm = bounds ~~ ((length bounds - 1) downto 0)
+      |> sort (fn (i,j) => int_ord (fst i, fst j)) |> map snd
+  in
+    subst_bounds (map Bound bperm, t)
+  end;
+
+fun mk_split_abs vs (Bound i) t = let val (x, T) = nth vs i in Abs (x, T, t) end
+  | mk_split_abs vs (Const ("Product_Type.Pair", _) $ u $ v) t =
+      HOLogic.mk_split (mk_split_abs vs u (mk_split_abs vs v t))
+  | mk_split_abs _ t _ = raise TERM ("mk_split_abs: bad term", [t]);
 
 fun mk_pointfree_expr t =
   let
     val (vs, t'') = strip_ex (dest_Collect t)
+    val Ts = map snd (rev vs)
+    fun mk_mem_UNIV n = HOLogic.mk_mem (Bound n, HOLogic.mk_UNIV (nth Ts n))
+    fun lookup (pat', t) pat = if pat = pat' then t else HOLogic.mk_UNIV (type_of_pattern Ts pat)
     val conjs = HOLogic.dest_conj t''
     val is_the_eq =
       the_default false o (try (fn eq => fst (HOLogic.dest_eq eq) = Bound (length vs)))
     val SOME eq = find_first is_the_eq conjs
     val f = snd (HOLogic.dest_eq eq)
     val conjs' = filter_out (fn t => eq = t) conjs
-    val mems = map (apfst dest_Bound o HOLogic.dest_mem) conjs'
-    val grouped_mems = AList.group (op =) mems
-    fun mk_grouped_unions (i, T) =
-      case AList.lookup (op =) grouped_mems i of
-        SOME ts => foldr1 mk_inf ts
-      | NONE => HOLogic.mk_UNIV T
-    val complete_sets = map mk_grouped_unions ((length vs - 1) downto 0 ~~ map snd vs)
+    val unused_bounds = subtract (op =) (distinct (op =) (maps loose_bnos conjs'))
+      (0 upto (length vs - 1))
+    val (pats, fm) =
+      mk_formula (foldr1 HOLogic.mk_conj (conjs' @ map mk_mem_UNIV unused_bounds))
+    fun mk_set (Atom pt) = (case map (lookup pt) pats of [t'] => t' | ts => foldr1 mk_sigma ts)
+      | mk_set (Un (f1, f2)) = mk_sup (mk_set f1, mk_set f2)
+      | mk_set (Int (f1, f2)) = mk_inf (mk_set f1, mk_set f2)
+    val pat = foldr1 (mk_prod1 Ts) (map (term_of_pattern Ts) pats)
+    val t = mk_split_abs (rev vs) pat (reorder_bounds pats f)
   in
-    mk_image (list_tupled_abs vs f) (foldr1 mk_sigma complete_sets)
+    (fm, mk_image t (mk_set fm))
   end;
 
 val rewrite_term = try mk_pointfree_expr
 
+
 (* proof tactic *)
 
-(* Tactic works for arbitrary number of m : S conjuncts *)
+val prod_case_distrib = @{lemma "(prod_case g x) z = prod_case (% x y. (g x y) z) x" by (simp add: prod_case_beta)}
+
+(* FIXME: one of many clones *)
+fun Trueprop_conv cv ct =
+  (case Thm.term_of ct of
+    Const (@{const_name Trueprop}, _) $ _ => Conv.arg_conv cv ct
+  | _ => raise CTERM ("Trueprop_conv", [ct]))
 
-val dest_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]}
-  THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE conjE}))
+(* FIXME: another clone *)
+fun eq_conv cv1 cv2 ct =
+  (case Thm.term_of ct of
+    Const (@{const_name HOL.eq}, _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv1) cv2 ct
+  | _ => raise CTERM ("eq_conv", [ct]))
+
+val elim_Collect_tac = dtac @{thm iffD1[OF mem_Collect_eq]}
+  THEN' (REPEAT_DETERM o (eresolve_tac @{thms exE}))
+  THEN' TRY o etac @{thm conjE}
   THEN' hyp_subst_tac;
 
-val intro_image_Sigma_tac = rtac @{thm image_eqI}
+fun intro_image_tac ctxt = rtac @{thm image_eqI}
     THEN' (REPEAT_DETERM1 o
       (rtac @{thm refl}
       ORELSE' rtac
-        @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]}));
+        @{thm arg_cong2[OF refl, where f="op =", OF prod.cases, THEN iffD2]}
+      ORELSE' CONVERSION (Conv.params_conv ~1 (K (Conv.concl_conv ~1
+        (Trueprop_conv (eq_conv Conv.all_conv (Conv.rewr_conv (mk_meta_eq prod_case_distrib)))))) ctxt)))
 
-val dest_image_Sigma_tac = etac @{thm imageE}
+val elim_image_tac = etac @{thm imageE}
   THEN' (TRY o REPEAT_DETERM1 o Splitter.split_asm_tac @{thms prod.split_asm})
   THEN' hyp_subst_tac
-  THEN' (TRY o REPEAT_DETERM1 o
-    (etac @{thm conjE} ORELSE' dtac @{thm iffD1[OF mem_Sigma_iff]}));
 
 val intro_Collect_tac = rtac @{thm iffD2[OF mem_Collect_eq]}
   THEN' REPEAT_DETERM1 o resolve_tac @{thms exI}
-  THEN' (TRY o REPEAT_ALL_NEW (rtac @{thm conjI}))
-  THEN' (K (ALLGOALS (TRY o ((TRY o hyp_subst_tac) THEN' rtac @{thm refl}))))
+  THEN' (TRY o (rtac @{thm conjI}))
+  THEN' (TRY o hyp_subst_tac)
+  THEN' rtac @{thm refl};
 
-val tac =
+fun tac1_of_formula (Int (fm1, fm2)) =
+    TRY o etac @{thm conjE}
+    THEN' rtac @{thm IntI}
+    THEN' (fn i => tac1_of_formula fm2 (i + 1))
+    THEN' tac1_of_formula fm1
+  | tac1_of_formula (Un (fm1, fm2)) =
+    etac @{thm disjE} THEN' rtac @{thm UnI1}
+    THEN' tac1_of_formula fm1
+    THEN' rtac @{thm UnI2}
+    THEN' tac1_of_formula fm2
+  | tac1_of_formula (Atom _) =
+    (REPEAT_DETERM1 o (rtac @{thm SigmaI}
+      ORELSE' rtac @{thm UNIV_I}
+      ORELSE' rtac @{thm iffD2[OF Compl_iff]}
+      ORELSE' atac))
+
+fun tac2_of_formula (Int (fm1, fm2)) =
+    TRY o etac @{thm IntE}
+    THEN' TRY o rtac @{thm conjI}
+    THEN' (fn i => tac2_of_formula fm2 (i + 1))
+    THEN' tac2_of_formula fm1
+  | tac2_of_formula (Un (fm1, fm2)) =
+    etac @{thm UnE} THEN' rtac @{thm disjI1}
+    THEN' tac2_of_formula fm1
+    THEN' rtac @{thm disjI2}
+    THEN' tac2_of_formula fm2
+  | tac2_of_formula (Atom _) =
+    TRY o REPEAT_DETERM1 o
+      (dtac @{thm iffD1[OF mem_Sigma_iff]}
+       ORELSE' etac @{thm conjE}
+       ORELSE' etac @{thm ComplE}
+       ORELSE' atac)
+
+fun tac ctxt fm =
   let
     val subset_tac1 = rtac @{thm subsetI}
-      THEN' dest_Collect_tac
-      THEN' intro_image_Sigma_tac
-      THEN' (REPEAT_DETERM1 o
-        (rtac @{thm SigmaI}
-        ORELSE' rtac @{thm UNIV_I}
-        ORELSE' rtac @{thm IntI}
-        ORELSE' atac));
-
+      THEN' elim_Collect_tac
+      THEN' (intro_image_tac ctxt)
+      THEN' (tac1_of_formula fm)
     val subset_tac2 = rtac @{thm subsetI}
-      THEN' dest_image_Sigma_tac
+      THEN' elim_image_tac
       THEN' intro_Collect_tac
-      THEN' REPEAT_DETERM o (eresolve_tac @{thms IntD1 IntD2} ORELSE' atac);
+      THEN' tac2_of_formula fm
   in
     rtac @{thm subset_antisym} THEN' subset_tac1 THEN' subset_tac2
   end;
 
+
+(* main simprocs *)
+
 fun conv ctxt t =
   let
     val ct = cterm_of (Proof_Context.theory_of ctxt) t
     val Bex_def = mk_meta_eq @{thm Bex_def}
     val unfold_eq = Conv.bottom_conv (K (Conv.try_conv (Conv.rewr_conv Bex_def))) ctxt ct
-    val t' = term_of (Thm.rhs_of unfold_eq) 
-    fun mk_thm t'' = Goal.prove ctxt [] []
-      (HOLogic.mk_Trueprop (HOLogic.mk_eq (t', t''))) (K (tac 1))
+    val t' = term_of (Thm.rhs_of unfold_eq)
+    fun mk_thm (fm, t'') = Goal.prove ctxt [] []
+      (HOLogic.mk_Trueprop (HOLogic.mk_eq (t', t''))) (fn {context, ...} => tac context fm 1)
     fun unfold th = th RS ((unfold_eq RS meta_eq_to_obj_eq) RS @{thm trans})
   in
     Option.map (unfold o mk_thm) (rewrite_term t')
   end;
 
-(* simproc *)
-
 fun base_simproc ss redex =
   let
     val ctxt = Simplifier.the_context ss