src/HOL/Tools/list_to_set_comprehension.ML
changeset 50425 79858bd9f5ef
parent 50421 eb7b59cc8e08
parent 50424 7c8ce63a3c00
child 50426 d2c60ada3ece
equal deleted inserted replaced
50421:eb7b59cc8e08 50425:79858bd9f5ef
     1 (*  Title:      HOL/Tools/list_to_set_comprehension.ML
       
     2     Author:     Lukas Bulwahn, TU Muenchen
       
     3 
       
     4 Simproc for rewriting list comprehensions applied to List.set to set
       
     5 comprehension.
       
     6 *)
       
     7 
       
     8 signature LIST_TO_SET_COMPREHENSION =
       
     9 sig
       
    10   val simproc : simpset -> cterm -> thm option
       
    11 end
       
    12 
       
    13 structure List_to_Set_Comprehension : LIST_TO_SET_COMPREHENSION =
       
    14 struct
       
    15 
       
    16 (* conversion *)
       
    17 
       
    18 fun all_exists_conv cv ctxt ct =
       
    19   (case Thm.term_of ct of
       
    20     Const (@{const_name HOL.Ex}, _) $ Abs _ =>
       
    21       Conv.arg_conv (Conv.abs_conv (all_exists_conv cv o #2) ctxt) ct
       
    22   | _ => cv ctxt ct)
       
    23 
       
    24 fun all_but_last_exists_conv cv ctxt ct =
       
    25   (case Thm.term_of ct of
       
    26     Const (@{const_name HOL.Ex}, _) $ Abs (_, _, Const (@{const_name HOL.Ex}, _) $ _) =>
       
    27       Conv.arg_conv (Conv.abs_conv (all_but_last_exists_conv cv o #2) ctxt) ct
       
    28   | _ => cv ctxt ct)
       
    29 
       
    30 fun Collect_conv cv ctxt ct =
       
    31   (case Thm.term_of ct of
       
    32     Const (@{const_name Set.Collect}, _) $ Abs _ => Conv.arg_conv (Conv.abs_conv cv ctxt) ct
       
    33   | _ => raise CTERM ("Collect_conv", [ct]))
       
    34 
       
    35 fun Trueprop_conv cv ct =
       
    36   (case Thm.term_of ct of
       
    37     Const (@{const_name Trueprop}, _) $ _ => Conv.arg_conv cv ct
       
    38   | _ => raise CTERM ("Trueprop_conv", [ct]))
       
    39 
       
    40 fun eq_conv cv1 cv2 ct =
       
    41   (case Thm.term_of ct of
       
    42     Const (@{const_name HOL.eq}, _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv1) cv2 ct
       
    43   | _ => raise CTERM ("eq_conv", [ct]))
       
    44 
       
    45 fun conj_conv cv1 cv2 ct =
       
    46   (case Thm.term_of ct of
       
    47     Const (@{const_name HOL.conj}, _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv1) cv2 ct
       
    48   | _ => raise CTERM ("conj_conv", [ct]))
       
    49 
       
    50 fun rewr_conv' th = Conv.rewr_conv (mk_meta_eq th)
       
    51 
       
    52 fun conjunct_assoc_conv ct =
       
    53   Conv.try_conv
       
    54     (rewr_conv' @{thm conj_assoc} then_conv conj_conv Conv.all_conv conjunct_assoc_conv) ct
       
    55 
       
    56 fun right_hand_set_comprehension_conv conv ctxt =
       
    57   Trueprop_conv (eq_conv Conv.all_conv
       
    58     (Collect_conv (all_exists_conv conv o #2) ctxt))
       
    59 
       
    60 
       
    61 (* term abstraction of list comprehension patterns *)
       
    62 
       
    63 datatype termlets = If | Case of (typ * int)
       
    64 
       
    65 fun simproc ss redex =
       
    66   let
       
    67     val ctxt = Simplifier.the_context ss
       
    68     val thy = Proof_Context.theory_of ctxt
       
    69     val set_Nil_I = @{thm trans} OF [@{thm set.simps(1)}, @{thm empty_def}]
       
    70     val set_singleton = @{lemma "set [a] = {x. x = a}" by simp}
       
    71     val inst_Collect_mem_eq = @{lemma "set A = {x. x : set A}" by simp}
       
    72     val del_refl_eq = @{lemma "(t = t & P) == P" by simp}
       
    73     fun mk_set T = Const (@{const_name List.set}, HOLogic.listT T --> HOLogic.mk_setT T)
       
    74     fun dest_set (Const (@{const_name List.set}, _) $ xs) = xs
       
    75     fun dest_singleton_list (Const (@{const_name List.Cons}, _)
       
    76           $ t $ (Const (@{const_name List.Nil}, _))) = t
       
    77       | dest_singleton_list t = raise TERM ("dest_singleton_list", [t])
       
    78     (* We check that one case returns a singleton list and all other cases
       
    79        return [], and return the index of the one singleton list case *)
       
    80     fun possible_index_of_singleton_case cases =
       
    81       let
       
    82         fun check (i, case_t) s =
       
    83           (case strip_abs_body case_t of
       
    84             (Const (@{const_name List.Nil}, _)) => s
       
    85           | _ => (case s of NONE => SOME i | SOME _ => NONE))
       
    86       in
       
    87         fold_index check cases NONE
       
    88       end
       
    89     (* returns (case_expr type index chosen_case) option  *)
       
    90     fun dest_case case_term =
       
    91       let
       
    92         val (case_const, args) = strip_comb case_term
       
    93       in
       
    94         (case try dest_Const case_const of
       
    95           SOME (c, T) =>
       
    96             (case Datatype.info_of_case thy c of
       
    97               SOME _ =>
       
    98                 (case possible_index_of_singleton_case (fst (split_last args)) of
       
    99                   SOME i =>
       
   100                     let
       
   101                       val (Ts, _) = strip_type T
       
   102                       val T' = List.last Ts
       
   103                     in SOME (List.last args, T', i, nth args i) end
       
   104                 | NONE => NONE)
       
   105             | NONE => NONE)
       
   106         | NONE => NONE)
       
   107       end
       
   108     (* returns condition continuing term option *)
       
   109     fun dest_if (Const (@{const_name If}, _) $ cond $ then_t $ Const (@{const_name Nil}, _)) =
       
   110           SOME (cond, then_t)
       
   111       | dest_if _ = NONE
       
   112     fun tac _ [] = rtac set_singleton 1 ORELSE rtac inst_Collect_mem_eq 1
       
   113       | tac ctxt (If :: cont) =
       
   114           Splitter.split_tac [@{thm split_if}] 1
       
   115           THEN rtac @{thm conjI} 1
       
   116           THEN rtac @{thm impI} 1
       
   117           THEN Subgoal.FOCUS (fn {prems, context, ...} =>
       
   118             CONVERSION (right_hand_set_comprehension_conv (K
       
   119               (conj_conv (Conv.rewr_conv (List.last prems RS @{thm Eq_TrueI})) Conv.all_conv
       
   120                then_conv
       
   121                rewr_conv' @{lemma "(True & P) = P" by simp})) context) 1) ctxt 1
       
   122           THEN tac ctxt cont
       
   123           THEN rtac @{thm impI} 1
       
   124           THEN Subgoal.FOCUS (fn {prems, context, ...} =>
       
   125               CONVERSION (right_hand_set_comprehension_conv (K
       
   126                 (conj_conv (Conv.rewr_conv (List.last prems RS @{thm Eq_FalseI})) Conv.all_conv
       
   127                  then_conv rewr_conv' @{lemma "(False & P) = False" by simp})) context) 1) ctxt 1
       
   128           THEN rtac set_Nil_I 1
       
   129       | tac ctxt (Case (T, i) :: cont) =
       
   130           let
       
   131             val info = Datatype.the_info thy (fst (dest_Type T))
       
   132           in
       
   133             (* do case distinction *)
       
   134             Splitter.split_tac [#split info] 1
       
   135             THEN EVERY (map_index (fn (i', _) =>
       
   136               (if i' < length (#case_rewrites info) - 1 then rtac @{thm conjI} 1 else all_tac)
       
   137               THEN REPEAT_DETERM (rtac @{thm allI} 1)
       
   138               THEN rtac @{thm impI} 1
       
   139               THEN (if i' = i then
       
   140                 (* continue recursively *)
       
   141                 Subgoal.FOCUS (fn {prems, context, ...} =>
       
   142                   CONVERSION (Thm.eta_conversion then_conv right_hand_set_comprehension_conv (K
       
   143                       ((conj_conv
       
   144                         (eq_conv Conv.all_conv (rewr_conv' (List.last prems)) then_conv
       
   145                           (Conv.try_conv (Conv.rewrs_conv (map mk_meta_eq (#inject info)))))
       
   146                         Conv.all_conv)
       
   147                         then_conv (Conv.try_conv (Conv.rewr_conv del_refl_eq))
       
   148                         then_conv conjunct_assoc_conv)) context
       
   149                     then_conv (Trueprop_conv (eq_conv Conv.all_conv (Collect_conv (fn (_, ctxt) =>
       
   150                       Conv.repeat_conv
       
   151                         (all_but_last_exists_conv
       
   152                           (K (rewr_conv'
       
   153                             @{lemma "(EX x. x = t & P x) = P t" by simp})) ctxt)) context)))) 1) ctxt 1
       
   154                 THEN tac ctxt cont
       
   155               else
       
   156                 Subgoal.FOCUS (fn {prems, context, ...} =>
       
   157                   CONVERSION
       
   158                     (right_hand_set_comprehension_conv (K
       
   159                       (conj_conv
       
   160                         ((eq_conv Conv.all_conv
       
   161                           (rewr_conv' (List.last prems))) then_conv
       
   162                           (Conv.rewrs_conv (map (fn th => th RS @{thm Eq_FalseI}) (#distinct info))))
       
   163                         Conv.all_conv then_conv
       
   164                         (rewr_conv' @{lemma "(False & P) = False" by simp}))) context then_conv
       
   165                       Trueprop_conv
       
   166                         (eq_conv Conv.all_conv
       
   167                           (Collect_conv (fn (_, ctxt) =>
       
   168                             Conv.repeat_conv
       
   169                               (Conv.bottom_conv
       
   170                                 (K (rewr_conv'
       
   171                                   @{lemma "(EX x. P) = P" by simp})) ctxt)) context))) 1) ctxt 1
       
   172                 THEN rtac set_Nil_I 1)) (#case_rewrites info))
       
   173           end
       
   174     fun make_inner_eqs bound_vs Tis eqs t =
       
   175       (case dest_case t of
       
   176         SOME (x, T, i, cont) =>
       
   177           let
       
   178             val (vs, body) = strip_abs (Pattern.eta_long (map snd bound_vs) cont)
       
   179             val x' = incr_boundvars (length vs) x
       
   180             val eqs' = map (incr_boundvars (length vs)) eqs
       
   181             val (constr_name, _) = nth (the (Datatype.get_constrs thy (fst (dest_Type T)))) i
       
   182             val constr_t =
       
   183               list_comb
       
   184                 (Const (constr_name, map snd vs ---> T), map Bound (((length vs) - 1) downto 0))
       
   185             val constr_eq = Const (@{const_name HOL.eq}, T --> T --> @{typ bool}) $ constr_t $ x'
       
   186           in
       
   187             make_inner_eqs (rev vs @ bound_vs) (Case (T, i) :: Tis) (constr_eq :: eqs') body
       
   188           end
       
   189       | NONE =>
       
   190           (case dest_if t of
       
   191             SOME (condition, cont) => make_inner_eqs bound_vs (If :: Tis) (condition :: eqs) cont
       
   192           | NONE =>
       
   193             if eqs = [] then NONE (* no rewriting, nothing to be done *)
       
   194             else
       
   195               let
       
   196                 val Type (@{type_name List.list}, [rT]) = fastype_of1 (map snd bound_vs, t)
       
   197                 val pat_eq =
       
   198                   (case try dest_singleton_list t of
       
   199                     SOME t' =>
       
   200                       Const (@{const_name HOL.eq}, rT --> rT --> @{typ bool}) $
       
   201                         Bound (length bound_vs) $ t'
       
   202                   | NONE =>
       
   203                       Const (@{const_name Set.member}, rT --> HOLogic.mk_setT rT --> @{typ bool}) $
       
   204                         Bound (length bound_vs) $ (mk_set rT $ t))
       
   205                 val reverse_bounds = curry subst_bounds
       
   206                   ((map Bound ((length bound_vs - 1) downto 0)) @ [Bound (length bound_vs)])
       
   207                 val eqs' = map reverse_bounds eqs
       
   208                 val pat_eq' = reverse_bounds pat_eq
       
   209                 val inner_t =
       
   210                   fold (fn (_, T) => fn t => HOLogic.exists_const T $ absdummy T t)
       
   211                     (rev bound_vs) (fold (curry HOLogic.mk_conj) eqs' pat_eq')
       
   212                 val lhs = term_of redex
       
   213                 val rhs = HOLogic.mk_Collect ("x", rT, inner_t)
       
   214                 val rewrite_rule_t = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs))
       
   215               in
       
   216                 SOME
       
   217                   ((Goal.prove ctxt [] [] rewrite_rule_t
       
   218                     (fn {context, ...} => tac context (rev Tis))) RS @{thm eq_reflection})
       
   219               end))
       
   220   in
       
   221     make_inner_eqs [] [] [] (dest_set (term_of redex))
       
   222   end
       
   223 
       
   224 end