src/HOL/Library/case_converter.ML
author haftmann
Fri Mar 22 19:18:08 2019 +0000 (4 months ago)
changeset 69946 494934c30f38
parent 69593 3dda49e08b9d
permissions -rw-r--r--
improved code equations taken over from AFP
Andreas@68155
     1
(* Author: Pascal Stoop, ETH Zurich
Andreas@68155
     2
   Author: Andreas Lochbihler, Digital Asset *)
Andreas@68155
     3
Andreas@68155
     4
(* Interface of the case converter: rewrites a list of pattern-matching
   equations for one constant into equations whose right-hand sides use case
   expressions at the positions chosen by an elimination strategy. *)
signature CASE_CONVERTER =
sig
  (* Chooses, for the patterns of one argument position, which sub-positions
     are to be turned into case distinctions. *)
  type elimination_strategy

  (* to_case ctxt strategy ctr_count eqns: the converted and proved equations,
     or NONE if eqns is empty or no position is selected. ctr_count is applied
     to constructor constants (presumably giving the number of constructors of
     their type -- confirm with callers). *)
  val to_case:
    Proof.context -> elimination_strategy -> (string * typ -> int) ->
      thm list -> thm list option

  (* Eliminate every constructor c of type constructor tyco for which the
     predicate holds on (tyco, c). *)
  val replace_by_type:
    (Proof.context -> string * string -> bool) -> elimination_strategy

  (* Keep the constructor prefix shared by all patterns and eliminate below it. *)
  val keep_constructor_context: elimination_strategy
end;
Andreas@68155
    12
Andreas@68155
    13
structure Case_Converter : CASE_CONVERTER =
Andreas@68155
    14
struct
Andreas@68155
    15
Andreas@68155
    16
(* lookup_remove eq k kvs: the first pair (k', v) of kvs with eq (k, k'),
   together with kvs without that pair (remaining order preserved);
   (NONE, kvs) when no pair matches. *)
fun lookup_remove eq k =
  let
    fun go [] = (NONE, [])
      | go ((entry as (k', _)) :: rest) =
          if eq (k, k') then (SOME entry, rest)
          else
            let
              val (found, rest') = go rest
            in
              (found, entry :: rest')
            end
  in
    go
  end
Andreas@68155
    20
Andreas@68155
    21
(* mk_abort msg t: the term missing_pattern_match msg (%_. t), which has the
   same type as t; t is delayed behind an abstraction over unit. *)
fun mk_abort msg t =
  let
    val T = fastype_of t
    val abortT = HOLogic.literalT --> (HOLogic.unitT --> T) --> T
    val abort = Const (\<^const_name>\<open>missing_pattern_match\<close>, abortT)
    val delayed = absdummy HOLogic.unitT t
  in
    abort $ HOLogic.mk_literal msg $ delayed
  end
Andreas@68155
    28
Andreas@68155
    29
(* fold_term const_fun free_fun var_fun bound_fun abs_fun dollar_fun term
   folds term bottom-up: each leaf goes through the function for its kind,
   an abstraction through abs_fun (name, type, folded body) and an
   application through dollar_fun (folded function, folded argument).

   Type:
     (string * typ -> 'a) -> (string * typ -> 'a) -> (indexname * typ -> 'a) ->
     (int -> 'a) -> (string * typ * 'a -> 'a) -> ('a * 'a -> 'a) -> term -> 'a *)
fun fold_term const_fun free_fun var_fun bound_fun abs_fun dollar_fun term =
  let
    fun walk (Const c) = const_fun c
      | walk (Free v) = free_fun v
      | walk (Var v) = var_fun v
      | walk (Bound i) = bound_fun i
      | walk (Abs (name, T, body)) = abs_fun (name, T, walk body)
      | walk (f $ u) = dollar_fun (walk f, walk u)
  in
    walk term
  end;
Andreas@68155
    49
Andreas@68155
    50
(* Positions inside a pattern:
   End T            -- the (sub)pattern here, of type T, is to be eliminated;
   Coordinate steps -- each step (c, (i, sub)) descends into argument i of
                       constructor c and continues with sub. *)
datatype term_coordinate =
    End of typ
  | Coordinate of (string * (int * term_coordinate)) list;
Andreas@68155
    52
Andreas@68155
    53
(* Merge two coordinate trees. An End on either side absorbs the other side
   (the left one wins if both are End). Otherwise steps with the same
   constructor name and argument index are merged recursively. The result
   lists the steps of the right tree first, followed by the left steps that
   had no counterpart. *)
fun term_coordinate_merge (End T) _ = End T
  | term_coordinate_merge _ (End T) = End T
  | term_coordinate_merge (Coordinate lefts) (Coordinate rights) =
      let
        fun same (name, (idx, _)) (name', (idx', _)) = name = name' andalso idx = idx'
        fun merge pending [] = pending
          | merge pending ((entry as (name, (idx, sub))) :: rest) =
              let
                val (hits, others) = List.partition (same entry) pending
                val merged =
                  (case hits of
                    [] => entry
                  | (_, (_, sub')) :: _ => (name, (idx, term_coordinate_merge sub' sub)))
              in
                merged :: merge others rest
              end
      in
        Coordinate (merge lefts rights)
      end;
Andreas@68155
    65
Andreas@68155
    66
(* Flatten a coordinate tree into one (T, path) pair per End T leaf, where
   path is the list of (constructor name, argument index) steps from the root. *)
fun coordinates_to_list (End T) = [(T, [])]
  | coordinates_to_list (Coordinate children) =
      maps (fn (name, (idx, sub)) =>
          map (fn (T, path) => (T, (name, idx) :: path)) (coordinates_to_list sub))
        children;
Andreas@68155
    71
Andreas@69568
    72
(* A strategy receives the patterns of one argument column (one per equation)
   and returns the coordinates to be eliminated in that column. *)
type elimination_strategy =
  Proof.context -> term list -> term_coordinate list
Andreas@69568
    73
Andreas@69568
    74
(* replace_by_type replace_ctr ctxt pats: descend through constructor
   applications of each pattern and stop (End) at a constructor c whose
   result type constructor tyco satisfies replace_ctr ctxt (tyco, c).
   Patterns containing no such constructor contribute nothing. *)
fun replace_by_type replace_ctr ctxt pats =
  let
    val should_replace = replace_ctr ctxt
    fun coords_of t =
      (case strip_comb t of
        (Const (name, T), args) =>
          let
            val resultT = body_type T
            fun arg_coords (i, arg) = Option.map (pair i) (coords_of arg)
          in
            if should_replace (fst (dest_Type resultT), name) then SOME (End resultT)
            else
              (case map_filter I (map_index arg_coords args) of
                [] => NONE
              | sub => SOME (Coordinate (map (pair name) sub)))
          end
      | _ => NONE)
  in
    map_filter coords_of pats
  end
Andreas@69568
    96
Andreas@69568
    97
(* keep_constructor_context ctxt pats: descend while all patterns of a column
   share the same head and that head is not the single constructor of its
   datatype, recursing into each argument column. Stop (End) at the first
   position where heads differ or the head is a single constructor.
   Raises an error if the head of the first pattern is a constant whose
   result type has no free-constructor information. *)
fun keep_constructor_context ctxt pats =
  let
    fun to_coordinates [] = NONE
      | to_coordinates pats =
        let
          val (fs, argss) = map strip_comb pats |> split_list
          val f = hd fs
          (* The free-constructor lookup happens exactly once, inside the case
             below, so an unregistered type reaches the error message instead
             of failing with an anonymous Option exception. *)
          fun is_single_ctr (Const (name, T)) =
              let
                val tyco = body_type T |> dest_Type |> fst
              in
                (case Ctr_Sugar.ctr_sugar_of ctxt tyco of
                  NONE => error ("Not a free constructor " ^ name ^ " in pattern")
                | SOME info =>
                    (case #ctrs info of
                      [Const (name', _)] => name = name'
                    | _ => false))
              end
            | is_single_ctr _ = false
        in
          if not (is_single_ctr f) andalso forall (fn x => f = x) fs then
            let
              val patss = Ctr_Sugar_Util.transpose argss
              fun recurse (i, pats) = to_coordinates pats |> Option.map (pair i)
              val coords = map_filter I (map_index recurse patss)
            in
              if null coords then NONE
              else SOME (Coordinate (map (pair (dest_Const f |> fst)) coords))
            end
          else SOME (End (body_type (fastype_of f)))
        end
  in
    the_list (to_coordinates pats)
  end
Andreas@69568
   131
Andreas@68155
   132
Andreas@68155
   133
(* AL: TODO: change from term to const_name *)
(* find_ctr ctr xs: remove and return the first entry of xs whose key constant
   has the same name as ctr (types are ignored), via lookup_remove. *)
fun find_ctr ctr1 xs =
  let
    fun same_name (c, c') = #1 (dest_Const c) = #1 (dest_Const c')
  in
    lookup_remove same_name ctr1 xs
  end;
Andreas@68155
   141
Andreas@68155
   142
(* Trees recording which argument tuples are covered:
   Wildcard -- not covered;
   Value    -- covered;
   Split (n, cases, default) -- distinguish on the next argument. n is the
     ctr_count of the constructor (0 while unknown). cases maps constructors
     to the subtree for their arguments followed by the remaining arguments.
     default is the subtree for a non-constructor (variable) argument. *)
datatype pattern =
    Wildcard
  | Value
  | Split of int * (term * pattern) list * pattern;
Andreas@68155
   146
Andreas@68155
   147
(* pattern_merge p q: combine the coverage of p and q. For two Split nodes,
   cases for the same constructor name are merged recursively. The right cases
   come first, followed by unmatched left cases. The defaults are merged too,
   and the left constructor count is kept unless it is unknown (<= 0). *)
fun pattern_merge Wildcard other = other
  | pattern_merge Value _ = Value
  | pattern_merge (Split (n, cases, dflt)) Wildcard =
      Split (n, map (apsnd (fn p => pattern_merge p Wildcard)) cases, pattern_merge dflt Wildcard)
  | pattern_merge (Split _) Value = Value
  | pattern_merge (Split (n, cases, dflt)) (Split (m, cases', dflt')) =
      let
        fun close p = pattern_merge p Wildcard
        fun combine remaining [] = map (apsnd close) remaining
          | combine remaining ((ctr, q) :: rest) =
              (case find_ctr ctr remaining of
                (SOME (ctr', p), remaining') => (ctr', pattern_merge p q) :: combine remaining' rest
              | (NONE, remaining') => (ctr, q) :: combine remaining' rest)
      in
        Split (if n <= 0 then m else n, combine cases cases', pattern_merge dflt dflt')
      end
Andreas@68155
   163
     
Andreas@68155
   164
(* pattern_intersect p q: the argument tuples covered by both p and q.
   intersect_consts cs cs' d d' intersects the case lists of two Split nodes
   with defaults d and d'. A constructor on both sides yields
   merge (merge (p1 /\ p2) (d /\ p2)) (p1 /\ d'). A constructor on one side
   only is intersected with the other side's default. *)
fun pattern_intersect Wildcard _ = Wildcard
  | pattern_intersect Value q = q
  | pattern_intersect (Split _) Wildcard = Wildcard
  | pattern_intersect (Split (n, cases, dflt)) Value =
      Split (n, map (apsnd (fn p => pattern_intersect p Value)) cases, pattern_intersect dflt Value)
  | pattern_intersect (Split (n, cases, dflt)) (Split (m, cases', dflt')) =
      Split (if n <= 0 then m else n,
        intersect_consts cases cases' dflt dflt',
        pattern_intersect dflt dflt')
and intersect_consts cases [] _ dflt' = map (apsnd (fn p => pattern_intersect p dflt')) cases
  | intersect_consts cases ((ctr, q) :: rest) dflt dflt' =
      (case find_ctr ctr cases of
        (SOME (ctr', p), cases') =>
          let
            val both = pattern_intersect p q
            val left_default = pattern_intersect dflt q
            val right_default = pattern_intersect p dflt'
          in
            (ctr', pattern_merge (pattern_merge both left_default) right_default)
              :: intersect_consts cases' rest dflt dflt'
          end
      | (NONE, cases') => (ctr, pattern_intersect dflt q) :: intersect_consts cases' rest dflt dflt')
Andreas@68155
   183
        
Andreas@68155
   184
(* pattern_lookup ts p: what remains of p after consuming the argument terms ts.
   - A constructor application follows its case, merged with the default subtree.
   - A variable follows the default subtree. If the Split records at least n
     cases (n > 0), the intersection of all cases (each entered with dummy
     variables for the constructor arguments) is merged in as well.
   - With no terms left, the remaining tree is rebuilt and returned. *)
fun pattern_lookup _ Wildcard = Wildcard
  | pattern_lookup _ Value = Value
  | pattern_lookup [] (Split (n, cases, dflt)) =
      Split (n, map (apsnd (pattern_lookup [])) cases, pattern_lookup [] dflt)
  | pattern_lookup (t :: ts) (Split (n, cases, dflt)) =
      let
        val (head, args) = strip_comb t
        fun continue p = pattern_lookup ts p
        fun follow_case (c, sub) =
          pattern_lookup (map (fn T => Free ("x", T)) (binder_types (snd (dest_Const c)))) sub
      in
        if is_Const head then
          (case find_ctr head cases of
            (SOME (_, sub), _) =>
              continue (pattern_merge (pattern_lookup args sub) (pattern_lookup [] dflt))
          | (NONE, _) => continue dflt)
        else if length cases < n orelse n <= 0 then continue dflt
        else
          let
            val others = map follow_case (tl cases)
            val first = follow_case (hd cases)
          in
            continue (pattern_merge (fold pattern_intersect others first) (pattern_lookup [] dflt))
          end
      end;
Andreas@68155
   209
Andreas@68155
   210
(* pattern_contains ts p: whether the argument tuple ts is covered by p
   (the lookup ends in Value rather than Wildcard). Raises Match when ts runs
   out while p still distinguishes further arguments. *)
fun pattern_contains terms pat =
  (case pattern_lookup terms pat of
    Value => true
  | Wildcard => false
  | Split _ => raise Match);
Andreas@68155
   214
Andreas@68155
   215
(* pattern_create ctr_count ts: a tree containing a single path along ts.
   Constructor applications contribute their arguments to the path. Every
   leaf and default is Wildcard, so the tree covers nothing by itself. *)
fun pattern_create _ [] = Wildcard
  | pattern_create ctr_count (t :: ts) =
      (case strip_comb t of
        (head as Const _, args) =>
          Split (ctr_count head, [(head, pattern_create ctr_count (args @ ts))], Wildcard)
      | _ => Split (0, [], pattern_create ctr_count ts));
Andreas@68155
   224
Andreas@68155
   225
(* pattern_insert ctr_count ts p: extend p so that it also covers the argument
   tuple ts. A Wildcard reached after consuming ts becomes Value. Missing
   constructor branches are created via pattern_create and prepended. An
   existing branch that is followed moves to the front of the case list.
   All other subtrees are rebuilt unchanged. *)
fun pattern_insert ctr_count terms pat =
  let
    fun fresh ts = pattern_insert ctr_count ts (pattern_create ctr_count ts)
    (* aux ts modify p: insert ts into p if modify; aux [] false p rebuilds p. *)
    fun aux _ false Wildcard = Wildcard
      | aux ts true Wildcard = if null ts then Value else fresh ts
      | aux _ _ Value = Value
      | aux ts modify (Split (n, cases, dflt)) =
          let
            val copied_cases = map (apsnd (aux [] false)) cases
            val copied_dflt = aux [] false dflt
          in
            (case ts of
              [] => Split (n, copied_cases, copied_dflt)
            | t :: rest =>
                let
                  val (head, args) = strip_comb t
                in
                  if not (is_Const head) then Split (n, copied_cases, aux rest modify dflt)
                  else
                    (case find_ctr head cases of
                      (SOME (ctr, sub), others) =>
                        Split (n, (ctr, aux (args @ rest) modify sub)
                          :: map (apsnd (aux [] false)) others, copied_dflt)
                    | (NONE, _) =>
                        if not modify then Split (n, copied_cases, copied_dflt)
                        else
                          Split (if n <= 0 then ctr_count head else n,
                            (head, fresh (args @ rest)) :: copied_cases, copied_dflt))
                end)
          end
  in
    aux terms true pat
  end;
Andreas@68155
   255
Andreas@68155
   256
(* The tree covering no argument tuple. *)
val pattern_empty : pattern = Wildcard;
Andreas@68155
   257
Andreas@68155
   258
(* replace_frees lhss rhss typ_list ctxt: give every equation fresh free
   variables for the variables of its left-hand side (also renamed in its
   right-hand side). Also invent fresh variables: def_frees and rhs_frees (one
   per type in typ_list each) and case_args (one per argument of the first
   left-hand side). Returns (lhss', rhss', def_frees ~~ rhs_frees, case_args, ctxt'). *)
fun replace_frees lhss rhss typ_list ctxt =
  let
    (* Variables named after "x", one per type, variant-fixed in the context. *)
    fun fresh_frees Ts ctxt =
      let
        val (frees, ctxt') = Ctr_Sugar_Util.mk_Frees "x" Ts ctxt
        val (names, Ts') = split_list (map dest_Free frees)
        val (names', ctxt'') = Variable.variant_fixes names ctxt'
      in
        (map Free (names' ~~ Ts'), ctxt'')
      end

    fun rename_equation (lhs, rhs) ctxt =
      let
        val old_frees = fold_rev Term.add_frees lhs []
        val (new_frees, ctxt') = fresh_frees (map snd old_frees) ctxt
        val renaming = old_frees ~~ new_frees
        fun rename_free v = the_default (Free v) (AList.lookup (op =) renaming v)
        val rename = fold_term Const rename_free Var Bound Abs (op $)
      in
        ((map rename lhs, rename rhs), ctxt')
      end

    val (def_frees, ctxt1) = fresh_frees typ_list ctxt
    val (rhs_frees, ctxt2) = fresh_frees typ_list ctxt1
    val (case_args, ctxt3) = fresh_frees (map fastype_of (hd lhss)) ctxt2
    val (renamed, ctxt4) = fold_map rename_equation (lhss ~~ rhss) ctxt3
    val (lhss', rhss') = split_list renamed
  in
    (lhss', rhss', def_frees ~~ rhs_frees, case_args, ctxt4)
  end;
Andreas@68155
   298
Andreas@68155
   299
(* add_names_in_type T: table transformer inserting the names of all type
   constructors occurring in T. *)
fun add_names_in_type (Type (name, Ts)) =
      fold_rev add_names_in_type Ts o Symtab.update (name, ())
  | add_names_in_type (TFree _) = I
  | add_names_in_type (TVar _) = I
Andreas@68155
   303
Andreas@68155
   304
(* add_names_in_term t: table transformer inserting the type-constructor names
   found in the types of the constants, variables and abstractions of t. *)
fun add_names_in_term t =
  (case t of
    Const (_, T) => add_names_in_type T
  | Free (_, T) => add_names_in_type T
  | Var (_, T) => add_names_in_type T
  | Bound _ => I
  | Abs (_, T, body) => add_names_in_type T o add_names_in_term body
  | t1 $ t2 => add_names_in_term t1 o add_names_in_term t2)
Andreas@68155
   311
Andreas@68155
   312
(* add_type_names ts: table transformer adding the type-constructor names of
   every term in ts, left to right. *)
fun add_type_names terms = fold add_names_in_term terms
Andreas@68155
   314
Andreas@68155
   315
(* get_split_theorems ctxt tab: the split rules of those type names in tab
   (in Symtab.keys order) that have free-constructor information in ctxt. *)
fun get_split_theorems ctxt tab =
  tab
  |> Symtab.keys
  |> map_filter (Ctr_Sugar.ctr_sugar_of ctxt)
  |> map #split;
Andreas@68155
   319
Andreas@68155
   320
(* match pat t: if t is an instance of pat (equal constant names; a Free of pat
   matches anything), SOME f where f maps each variable of pat to the subterm
   it matched and leaves other terms unchanged; NONE otherwise. Repeated
   variables are not checked for consistency. *)
fun match pat t =
  (case (pat, t) of
    (Const (c, _), Const (c', _)) => if c = c' then SOME I else NONE
  | (Free v, _) => SOME (fn u => if u = Free v then t else u)
  | (p1 $ p2, t1 $ t2) =>
      (case (match p1 t1, match p2 t2) of
        (SOME f, SOME g) => SOME (f o g)
      | _ => NONE)
  | _ => NONE);
Andreas@68155
   328
Andreas@68155
   329
(* match_all pats ts: match pats against ts pairwise (equal lengths required);
   SOME of the composed substitution functions, or NONE if any pair fails. *)
fun match_all patterns terms =
  let
    fun step (pat, t) acc =
      Option.mapPartial (fn g => Option.map (fn f => f o g) (match pat t)) acc
  in
    fold_rev step (patterns ~~ terms) (SOME I)
  end
Andreas@68155
   337
Andreas@68155
   338
(* matches pat t: whether t is an instance of pat (same test as match, without
   building a substitution). *)
fun matches pat t =
  (case (pat, t) of
    (Const (c, _), Const (c', _)) => c = c'
  | (Free _, _) => true
  | (p1 $ p2, t1 $ t2) => matches p1 t1 andalso matches p2 t2
  | _ => false);
Andreas@68155
   342
(* matches_all pats ts: whether each pattern matches the term at the same
   position (equal lengths required). *)
fun matches_all patterns terms =
  forall (fn (pat, t) => matches pat t) (patterns ~~ terms)
Andreas@68155
   343
Andreas@68155
   344
(* terms_to_case_at ctr_count ctxt fun_t default_lhs (pos, (lazy_case_arg, rhs_free))
     (lhss, rhss, type_name_fun)
   Eliminates one position from the equations fun_t lhss_i = rhss_i.
   pos = (n, path) selects argument n and, inside it, the subterm reached by
   following path (a list of (constructor name, argument index) steps).
   - That subterm is replaced by the variable lazy_case_arg in every
     left-hand side.
   - The left-hand sides are re-split and redundant ones are removed.
   - Each new right-hand side is a case expression over lazy_case_arg, built
     from the matching original equations plus a final missing_pattern_match
     catch-all.
   - default_lhs (used as the last left-hand side) adds a catch-all equation.
   Returns the new left-hand sides and right-hand sides, and type_name_fun
   extended with the type names of the extracted subpatterns. *)
fun terms_to_case_at ctr_count ctxt (fun_t : term) (default_lhs : term list)
    (pos, (lazy_case_arg, rhs_free))
    ((lhss : term list list), (rhss : term list), type_name_fun) =
  let
    (* missing_pattern_match term for t, with a message naming t's head constant. *)
    fun abort t =
      let
        val fun_name = head_of t |> dest_Const |> fst
        val msg = "Missing pattern in " ^ fun_name ^ "."
      in
        mk_abort msg t
      end;

    (* Step 1 : Eliminate lazy pattern *)
    (* replace_pat_at (n, tcos) pat pats: in argument n of pats, replace the
       subterm at path tcos by pat. Returns the new argument list and the
       replaced subterm. If the path does not occur (different constructor or
       non-constant head), the argument is kept and rhs_free is returned instead. *)
    fun replace_pat_at (n, tcos) pat pats =
      let
        (* Apply f to the n-th element; f yields (new element, extra value). *)
        fun map_at _ _ [] = raise Empty
          | map_at n f (x :: xs) = if n > 0
            then apfst (cons x) (map_at (n - 1) f xs)
            else apfst (fn x => x :: xs) (f x)
        fun replace [] pat term = (pat, term)
          | replace ((s1, n) :: tcos) pat term =
            let
              val (ctr, args) = strip_comb term
            in
              case ctr of Const (s2, _) =>
                  if s1 = s2
                  then apfst (pair ctr #> list_comb) (map_at n (replace tcos pat) args)
                  else (term, rhs_free)
                | _ => (term, rhs_free)
            end
        val (part1, (old_pat, part2)) = chop n pats ||> (fn xs => (hd xs, tl xs))
        val (new_pat, old_pat1) = replace tcos pat old_pat
      in
        (part1 @ [new_pat] @ part2, old_pat1)
      end
    (* lhss1: left-hand sides with lazy_case_arg inserted; lazy_pats: the extracted subterms. *)
    val (lhss1, lazy_pats) = map (replace_pat_at pos lazy_case_arg) lhss
      |> split_list

    (* Step 2 : Split patterns *)
    (* Combine left-hand sides that unify, tracking in a pattern tree which
       extracted subpatterns each combined left-hand side already covers;
       default_lhs is appended at the end. *)
    fun split equs =
      let
        (* A pattern at least as specific as both (a Free yields the other side);
           NONE on a constant clash or shape mismatch. *)
        fun merge_pattern (Const (s1, T1), Const (s2, _)) =
            if s1 = s2 then SOME (Const (s1, T1)) else NONE
          | merge_pattern (t, Free _) = SOME t
          | merge_pattern (Free _, t) = SOME t
          | merge_pattern (t1l $ t1r, t2l $ t2r) =
            (case (merge_pattern (t1l, t2l), merge_pattern (t1r, t2r)) of
              (SOME t1, SOME t2) => SOME (t1 $ t2)
              | _ => NONE)
          | merge_pattern _ = NONE
        (* Pointwise merge_pattern; raises Match on lists of different length. *)
        fun merge_patterns pats1 pats2 = case (pats1, pats2) of
          ([], []) => SOME []
          | (x :: xs, y :: ys) =>
            (case (merge_pattern (x, y), merge_patterns xs ys) of
              (SOME x, SOME xs) => SOME (x :: xs)
              | _ => NONE
            )
          | _ => raise Match
        (* For each existing entry (lhs2, pat): if lhs1 merges with lhs2 and
           case_pat is not yet covered by pat, add the merged lhs with case_pat
           inserted. The existing entry is kept unless lhs1 matches lhs2 and the
           merged entry was added. An entry for lhs1 alone is always appended
           at the end. *)
        fun merge_insert ((lhs1, case_pat), _) [] =
            [(lhs1, pattern_empty |> pattern_insert ctr_count [case_pat])]
          | merge_insert ((lhs1, case_pat), rhs) ((lhs2, pat) :: pats) =
            let
              val pats = merge_insert ((lhs1, case_pat), rhs) pats
              val (first_equ_needed, new_lhs) = case merge_patterns lhs1 lhs2 of
                SOME new_lhs => (not (pattern_contains [case_pat] pat), new_lhs)
                | NONE => (false, lhs2)
              val second_equ_needed = not (matches_all lhs1 lhs2)
                orelse not first_equ_needed
              val first_equ = if first_equ_needed
                then [(new_lhs, pattern_insert ctr_count [case_pat] pat)]
                else []
              val second_equ = if second_equ_needed
                then [(lhs2, pat)]
                else []
            in
              first_equ @ second_equ @ pats
            end
        in
          (fold merge_insert equs []
            |> split_list
            |> fst) @ [default_lhs]
        end
    val lhss2 = split ((lhss1 ~~ lazy_pats) ~~ rhss)

    (* Step 3 : Remove redundant patterns *)
    (* Drop left-hand sides whose argument tuple is covered by earlier ones. *)
    fun remove_redundant_lhs lhss =
      let
        fun f lhs pat = if pattern_contains lhs pat
          then ((lhs, false), pat)
          else ((lhs, true), pattern_insert ctr_count lhs pat)
      in
        fold_map f lhss pattern_empty
        |> fst
        |> filter snd
        |> map fst
      end
    (* Drop (pattern, rhs) cases whose pattern is covered by earlier cases. *)
    fun remove_redundant_rhs rhss =
      let
        fun f (lhs, rhs) pat = if pattern_contains [lhs] pat
          then (((lhs, rhs), false), pat)
          else (((lhs, rhs), true), pattern_insert ctr_count [lhs] pat)
      in
        map fst (filter snd (fold_map f rhss pattern_empty |> fst))
      end
    val lhss3 = remove_redundant_lhs lhss2

    (* Step 4 : Compute right hand side *)
    (* Apply f to every Free leaf and to both (already folded) sides of every application. *)
    fun subs_fun f = fold_term
      Const
      (f o Free)
      Var
      Bound
      Abs
      (fn (x, y) => f x $ f y)
    (* Cases (extracted pattern, instantiated rhs) of all original equations
       whose left-hand side matches lhs, then the abort catch-all, with
       redundant cases removed. *)
    fun find_rhss lhs =
      let
        fun f (lhs1, (pat, rhs)) =
          case match_all lhs1 lhs of NONE => NONE
          | SOME f => SOME (pat, subs_fun f rhs)
      in
        remove_redundant_rhs
          (map_filter f (lhss1 ~~ (lazy_pats ~~ rhss)) @
            [(lazy_case_arg, list_comb (fun_t, lhs) |> abort)]
          )
      end

    (* Step 5 : make_case of right hand side *)
    (* A single case with a variable pattern is substituted directly; otherwise
       Case_Translation.make_case builds the case expression. *)
    fun make_case ctxt case_arg cases = case cases of
      [(Free x, rhs)] => subs_fun (fn y => if y = Free x then case_arg else y) rhs
      | _ => Case_Translation.make_case
        ctxt
        Case_Translation.Warning
        Name.context
        case_arg
        cases
    val type_name_fun = add_type_names lazy_pats o type_name_fun
    val rhss3 = map ((make_case ctxt lazy_case_arg) o find_rhss) lhss3
  in
    (lhss3, rhss3, type_name_fun)
  end;
Andreas@68155
   484
Andreas@68155
   485
(* terms_to_case ctxt ctr_count head lhss rhss typ_list poss: rename the
   equation variables apart, then eliminate the positions in poss one at a
   time (last position first) with terms_to_case_at. Returns the resulting
   equations as propositions head lhs = rhs, the split rules of the types
   involved, and the extended context. *)
fun terms_to_case ctxt ctr_count (head : term) (lhss : term list list)
    (rhss : term list) (typ_list : typ list) (poss : (int * (string * int) list) list) =
  let
    val (lhss', rhss', def_frees, case_args, ctxt') = replace_frees lhss rhss typ_list ctxt
    val step = terms_to_case_at ctr_count ctxt' head case_args
    val (final_lhss, final_rhss, add_names) =
      fold_rev step (poss ~~ def_frees) (lhss', rhss', I)
    fun to_prop (args, rhs) =
      HOLogic.mk_Trueprop (HOLogic.mk_eq (list_comb (head, args), rhs))
    val eqs = map to_prop (final_lhss ~~ final_rhss)
    val splits = get_split_theorems ctxt' (add_names Symtab.empty)
  in
    (eqs, splits, ctxt')
  end;
Andreas@68155
   500
Andreas@69568
   501
Andreas@69568
   502
(* build_case_t strategy ctr_count head lhss rhss ctxt:
   - Checks the equations: matching non-zero numbers of left- and right-hand
     sides, and equal arity.
   - Asks the strategy for the coordinates to eliminate in each argument column.
   - Passes them to terms_to_case.
   Returns ([], [], ctxt) when nothing is selected. *)
fun build_case_t elimination_strategy ctr_count head lhss rhss ctxt =
  let
    val num_eqs = length lhss
    val _ =
      if length rhss <> num_eqs orelse num_eqs <= 0 then
        raise Fail
          ("expected same number of left-hand sides as right-hand sides\n"
            ^ "and at least one equation")
      else ()
    val arity = length (hd lhss)
    val _ =
      if exists (fn args => length args <> arity) lhss
      then raise Fail "expected equal number of arguments"
      else ()

    (* Coordinates chosen in column i, flattened to (T, (i, path)) pairs. *)
    fun column_coordinates (i, column) =
      (case elimination_strategy ctxt column of
        [] => []
      | tco :: tcos =>
          fold term_coordinate_merge tcos tco
          |> coordinates_to_list
          |> map (fn (T, path) => (T, (i, path))))
    val (typ_list, poss) =
      lhss
      |> Ctr_Sugar_Util.transpose
      |> map_index column_coordinates
      |> flat
      |> split_list
  in
    if null poss then ([], [], ctxt)
    else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss
  end;
Andreas@68155
   528
Andreas@68155
   529
(* tac ctxt {splits, intros, defs}: proof tactic for a generated equation.
   1. Repeatedly split case expressions with the rules splits, breaking up
      conjunctions, universal quantifiers and implications (substituting the
      introduced equational hypotheses).
   2. Unfold missing_pattern_match_def and the definitions defs.
   3. Each resulting subgoal must be solved by resolution with refl or one of
      the intros. *)
fun tac ctxt {splits, intros, defs} =
  let
    val break_up =
      resolve_tac ctxt [@{thm conjI}, @{thm allI}]
      ORELSE' (resolve_tac ctxt [@{thm impI}] THEN' hyp_subst_tac_thin true ctxt)
    val split_and_subst = split_tac ctxt splits THEN' REPEAT_ALL_NEW break_up
    val unfold_abort = K (Local_Defs.unfold_tac ctxt [@{thm missing_pattern_match_def}])
    val unfold_defs = K (Local_Defs.unfold_tac ctxt defs)
    val finish = SOLVED' (resolve_tac ctxt (@{thm refl} :: intros))
  in
    (REPEAT_ALL_NEW split_and_subst ORELSE' K all_tac)
    THEN' unfold_abort
    THEN' unfold_defs
    THEN_ALL_NEW finish
  end;
Andreas@68155
   543
Andreas@68155
   544
(* to_case ctxt strategy ctr_count ths: convert the defining equations ths of
   one constant so that the positions chosen by the strategy are matched by
   case expressions, and prove the new equations from ths. NONE if ths is
   empty or nothing is selected. *)
fun to_case _ _ _ [] = NONE
  | to_case ctxt replace_ctr ctr_count ths =
    let
      val strip_eq = Thm.prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq

      (* Instantiate the remaining theorems so that their heads agree with the
         head of the first one, then import all of them into the context. *)
      fun import [] ctxt = ([], ctxt)
        | import (thm :: thms) ctxt =
            let
              val head_cterm = strip_eq #> fst #> head_of #> Logic.mk_term #> Thm.cterm_of ctxt
              val target = head_cterm thm
              val heads = map head_cterm thms
              val aligned =
                map2 (fn th => fn ct => Thm.instantiate (Thm.match (ct, target)) th) thms heads
            in
              Variable.import true (thm :: aligned) ctxt |> apfst snd
            end

      val (imported, ctxt') = import ths ctxt
      val head = imported |> hd |> strip_eq |> fst |> head_of
      val eqs = map (strip_eq #> apfst (snd o strip_comb)) imported

      (* Define a constant (named name) for rhs abstracted over the free
         variables of pat. Return that constant applied to those variables,
         together with its defining theorem. *)
      fun hide_rhs ((pat, rhs), name) lthy =
        let
          val frees = fold Term.add_frees pat []
          val abs_rhs = fold absfree frees rhs
          val spec = ((Binding.name name, NoSyn), (Binding.empty_atts, abs_rhs))
        in
          (case Local_Defs.define [spec] lthy of
            ([(f, (_, def))], lthy') => ((list_comb (f, map Free (rev frees)), def), lthy')
          | _ => raise Match)
        end

      val rhs_names = Name.invent (Variable.names_of ctxt') "rhs" (length eqs)
      val ((def_ts, def_thms), ctxt2) =
        fold_map hide_rhs (eqs ~~ rhs_names) ctxt' |> apfst split_list
      val (ts, split_thms, ctxt3) =
        build_case_t replace_ctr ctr_count head (map fst eqs) def_ts ctxt2
      fun prove t =
        Goal.prove ctxt3 [] [] t
          (fn {context = goal_ctxt, ...} =>
            tac goal_ctxt {splits = split_thms, intros = ths, defs = def_thms} 1)
    in
      if null ts then NONE
      else SOME (map (Goal.norm_result ctxt) (Proof_Context.export ctxt3 ctxt (map prove ts)))
    end;
Andreas@68155
   592
Andreas@68155
   593
end