src/HOL/Library/case_converter.ML
author wenzelm
Tue May 15 13:57:39 2018 +0200 (16 months ago)
changeset 68189 6163c90694ef
parent 68155 8b50f29a1992
child 68301 fb5653a7a879
permissions -rw-r--r--
tuned headers;
Andreas@68155
     1
(* Author: Pascal Stoop, ETH Zurich
Andreas@68155
     2
   Author: Andreas Lochbihler, Digital Asset *)
Andreas@68155
     3
Andreas@68155
     4
signature CASE_CONVERTER =
sig
  (* to_case ctxt replace_ctr ctr_count ths

     ths:         equations of one function, given as theorems
     replace_ctr: predicate on (type constructor name, constant name); it
                  selects the constructor occurrences in the left-hand sides
                  that are turned into case expressions
     ctr_count:   called on the (name, type) pair of a constructor constant;
                  presumably the number of constructors of its datatype
                  (verify against callers)

     The result is NONE when ths is empty or when no position in the
     left-hand sides is selected by replace_ctr. *)
  val to_case:
    Proof.context ->
    (string * string -> bool) ->
    (string * typ -> int) ->
    thm list ->
    thm list option
end;
Andreas@68155
     9
Andreas@68155
    10
structure Case_Converter : CASE_CONVERTER =
Andreas@68155
    11
struct
Andreas@68155
    12
Andreas@68155
    13
(* lookup_remove eq key entries: find the first entry whose key k' satisfies
   eq (key, k').  Returns that entry (if any) together with the list of all
   other entries in their original order. *)
fun lookup_remove eq key entries =
  (case entries of
    [] => (NONE, [])
  | (entry as (key', _)) :: rest =>
      if eq (key, key') then (SOME entry, rest)
      else
        let
          val (found, remaining) = lookup_remove eq key rest
        in
          (found, entry :: remaining)
        end)
Andreas@68155
    17
Andreas@68155
    18
(* Apply f underneath SOME; NONE is passed through unchanged. *)
fun map_option f opt =
  (case opt of
    SOME x => SOME (f x)
  | NONE => NONE)
Andreas@68155
    20
Andreas@68155
    21
(* mk_abort msg t: the term "missing_pattern_match msg (%_::unit. t)", which
   has the same type as t. *)
fun mk_abort msg t =
  let
    val resT = fastype_of t
    val abortT = HOLogic.literalT --> (HOLogic.unitT --> resT) --> resT
    val abort_const = Const (@{const_name missing_pattern_match}, abortT)
    val literal = HOLogic.mk_literal msg
    (* t is delayed behind a dummy unit abstraction *)
    val delayed = absdummy HOLogic.unitT t
  in
    abort_const $ literal $ delayed
  end
Andreas@68155
    28
Andreas@68155
    29
(* Catamorphism over terms: one handler per term constructor.  Sub-terms of
   abstractions and applications are processed before their handler is called.

   fold_term : (string * typ -> 'a) ->       Const
               (string * typ -> 'a) ->       Free
               (indexname * typ -> 'a) ->    Var
               (int -> 'a) ->                Bound
               (string * typ * 'a -> 'a) ->  Abs
               ('a * 'a -> 'a) ->            $
               term ->
               'a *)
fun fold_term const_fun free_fun var_fun bound_fun abs_fun dollar_fun term =
  let
    fun go (Const c) = const_fun c
      | go (Free v) = free_fun v
      | go (Var v) = var_fun v
      | go (Bound i) = bound_fun i
      | go (Abs (x, T, body)) = abs_fun (x, T, go body)
      | go (f $ u) = dollar_fun (go f, go u)
  in
    go term
  end;
Andreas@68155
    49
Andreas@68155
    50
(* Tree of positions inside a constructor term.
   End T:         the position itself; T is the result type of the
                  constructor found there (see term_to_coordinates).
   Coordinate xs: each entry (c, (i, sub)) descends into argument i of an
                  application of the constant named c. *)
datatype term_coordinate =
    End of typ
  | Coordinate of (string * (int * term_coordinate)) list;
Andreas@68155
    52
Andreas@68155
    53
(* Merge two coordinate trees.  An End on either side wins (the left one if
   both are Ends).  For two Coordinate nodes, every entry of the right list
   is merged with the first left entry having the same constant name and
   argument index; all left entries for that slot are removed from the
   remainder.  Left entries without a partner are kept at the end. *)
fun term_coordinate_merge (End T) _ = End T
  | term_coordinate_merge _ (End T) = End T
  | term_coordinate_merge (Coordinate left) (Coordinate right) =
      let
        fun merge_entries rest [] = rest
          | merge_entries rest ((name, (idx, sub)) :: more) =
              let
                fun same_slot (name', (idx', _)) = name = name' andalso idx = idx'
                val (hits, others) = List.partition same_slot rest
                val merged =
                  (case hits of
                    [] => sub
                  | (_, (_, sub')) :: _ => term_coordinate_merge sub' sub)
              in
                (name, (idx, merged)) :: merge_entries others more
              end
      in
        Coordinate (merge_entries left right)
      end;
Andreas@68155
    65
Andreas@68155
    66
(* term_to_coordinates P term: positions inside term whose head is a constant
   selected by P.

   P is applied to (name of the type constructor of the constant's result
   type, constant name).  If it holds, the position is an End carrying the
   result type.  Otherwise the arguments are searched recursively, and NONE
   is returned when nothing is found below.  Terms whose head is not a
   constant yield NONE.

   A constant whose result type is not a type constructor application (a
   type variable) counts as not selected; P is not called for it.  Applying
   dest_Type to such a type would raise TYPE. *)
fun term_to_coordinates P term =
  let
    val (ctr, args) = strip_comb term
  in
    (case ctr of
      Const (s, T) =>
        let
          val resT = body_type T
          val selected =
            (case resT of
              Type (tyco, _) => P (tyco, s)
            | _ => false)
        in
          if selected then SOME (End resT)
          else
            let
              (* keep the argument index with every coordinate found below *)
              fun f (i, t) = term_to_coordinates P t |> map_option (pair i)
              val tcos = map_filter I (map_index f args)
            in
              if null tcos then NONE
              else SOME (Coordinate (map (pair s) tcos))
            end
        end
    | _ => NONE)
  end;
Andreas@68155
    83
Andreas@68155
    84
(* Flatten a coordinate tree into one (type, path) pair per End leaf.  A path
   lists the (constant name, argument index) steps from the root to the leaf. *)
fun coordinates_to_list tco =
  (case tco of
    End T => [(T, [])]
  | Coordinate entries =>
      let
        fun paths_below (name, (idx, sub)) =
          map (fn (T, path) => (T, (name, idx) :: path)) (coordinates_to_list sub)
      in
        flat (map paths_below entries)
      end);
Andreas@68155
    89
Andreas@68155
    90
Andreas@68155
    91
(* AL: TODO: change from term to const_name *)
(* Look up (and remove) the entry of xs whose key is a constant with the same
   name as ctr1; types are ignored.  dest_Const is applied to ctr1 and to the
   keys that are inspected, so they must be constants. *)
fun find_ctr ctr1 xs =
  let
    fun name_of t = #1 (dest_Const t)
    fun same_name (a, b) = name_of a = name_of b
  in
    lookup_remove same_name ctr1 xs
  end;
Andreas@68155
    99
Andreas@68155
   100
(* Sets of argument lists, organised as a decision tree.
   Wildcard: nothing is covered here (pattern_contains yields false).
   Value:    everything is covered here (pattern_contains yields true).
   Split (n, branches, default):
     branches maps constructor constants to the pattern for the constructor's
     arguments followed by the remaining arguments; default applies to
     non-constructor terms; n is the count obtained from ctr_count for the
     constructors of this split, or 0 when pattern_create saw no constructor. *)
datatype pattern =
    Wildcard
  | Value
  | Split of int * (term * pattern) list * pattern;
Andreas@68155
   104
Andreas@68155
   105
(* Union of two patterns.  Wildcard is neutral on the left, Value absorbs.
   Two splits are merged branch by branch: branches for the same constructor
   name are merged recursively (keeping the left constructor term), branches
   occurring on one side only are kept, and the defaults are merged.  The
   left count is used unless it is not positive. *)
fun pattern_merge Wildcard other = other
  | pattern_merge Value _ = Value
  | pattern_merge (Split _) Value = Value
  | pattern_merge (Split (n, branches, default)) Wildcard =
      let
        fun close p = pattern_merge p Wildcard
      in
        Split (n, map (apsnd close) branches, close default)
      end
  | pattern_merge (Split (n, left, default1)) (Split (m, right, default2)) =
      let
        fun close p = pattern_merge p Wildcard
        fun merge_branches remaining [] = map (apsnd close) remaining
          | merge_branches remaining ((ctr, p2) :: more) =
              let
                val (hit, remaining') = find_ctr ctr remaining
                val branch =
                  (case hit of
                    SOME (ctr', p1) => (ctr', pattern_merge p1 p2)
                  | NONE => (ctr, p2))
              in
                branch :: merge_branches remaining' more
              end
        val count = if n <= 0 then m else n
      in
        Split (count, merge_branches left right, pattern_merge default1 default2)
      end
Andreas@68155
   121
     
Andreas@68155
   122
(* Intersection of two patterns.  Wildcard absorbs, Value is neutral.
   For two splits, intersect_consts combines the branch lists: a constructor
   present on both sides gets the union of the three ways both sides can
   cover it (branch/branch, left default/right branch, left branch/right
   default); a constructor present on one side only is intersected with the
   other side's default. *)
fun pattern_intersect Wildcard _ = Wildcard
  | pattern_intersect Value other = other
  | pattern_intersect (Split _) Wildcard = Wildcard
  | pattern_intersect (Split (n, branches, default)) Value =
      let
        fun restrict p = pattern_intersect p Value
      in
        Split (n, map (apsnd restrict) branches, restrict default)
      end
  | pattern_intersect (Split (n, left, default1)) (Split (m, right, default2)) =
      let
        val count = if n <= 0 then m else n
        val branches = intersect_consts left right default1 default2
      in
        Split (count, branches, pattern_intersect default1 default2)
      end
and intersect_consts remaining [] _ default2 =
      map (apsnd (fn p => pattern_intersect p default2)) remaining
  | intersect_consts remaining ((ctr, p2) :: more) default1 default2 =
      let
        val (hit, remaining') = find_ctr ctr remaining
        val branch =
          (case hit of
            SOME (ctr', p1) =>
              let
                val both = pattern_intersect p1 p2
                val via_default1 = pattern_intersect default1 p2
                val via_default2 = pattern_intersect p1 default2
              in
                (ctr', pattern_merge (pattern_merge both via_default1) via_default2)
              end
          | NONE => (ctr, pattern_intersect default1 p2))
      in
        branch :: intersect_consts remaining' more default1 default2
      end
Andreas@68155
   141
        
Andreas@68155
   142
(* pattern_lookup terms pat: follow the argument list terms through pat and
   return the sub-pattern that is reached.

   At a split, a constructor application selects its branch (merged with the
   default) or, if there is no such branch, the default.  A non-constructor
   term selects the default alone when the split has fewer branches than its
   count n, or n is not positive.  Otherwise it selects the intersection of
   all branches, each looked up with fresh variables "x" as constructor
   arguments, merged with the default. *)
fun pattern_lookup _ Wildcard = Wildcard
  | pattern_lookup _ Value = Value
  | pattern_lookup [] (Split (n, branches, default)) =
      Split (n, map (apsnd (pattern_lookup [])) branches, pattern_lookup [] default)
  | pattern_lookup (t :: rest) (Split (n, branches, default)) =
      let
        val (head, args) = strip_comb t
        (* look up a branch with one variable per constructor argument *)
        fun lookup_generic (ctr, sub) =
          let
            val argTs = binder_types (snd (dest_Const ctr))
          in
            pattern_lookup (map (fn T => Free ("x", T)) argTs) sub
          end
        val selected =
          if is_Const head then
            (case find_ctr head branches of
              (SOME (_, sub), _) =>
                pattern_merge (pattern_lookup args sub) (pattern_lookup [] default)
            | (NONE, _) => default)
          else if length branches < n orelse n <= 0 then default
          else
            pattern_merge
              (fold pattern_intersect (map lookup_generic (tl branches))
                (lookup_generic (hd branches)))
              (pattern_lookup [] default)
      in
        pattern_lookup rest selected
      end;
Andreas@68155
   167
Andreas@68155
   168
(* Does pat cover the argument list terms?  Raises Match when the lookup ends
   at a split, i.e. when terms does not lead to a leaf of pat. *)
fun pattern_contains terms pat =
  let
    fun decide Value = true
      | decide Wildcard = false
      | decide (Split _) = raise Match
  in
    decide (pattern_lookup terms pat)
  end;
Andreas@68155
   172
Andreas@68155
   173
(* Skeleton pattern for the argument list terms: one split per term, with
   Wildcard at every leaf.  A constructor application gives a split with a
   single branch for that constructor (count taken from ctr_count), continuing
   with its arguments followed by the remaining terms.  Any other term gives a
   split with count 0 and only a default. *)
fun pattern_create ctr_count terms =
  (case terms of
    [] => Wildcard
  | t :: rest =>
      (case strip_comb t of
        (head as Const _, args) =>
          Split (ctr_count head, [(head, pattern_create ctr_count (args @ rest))], Wildcard)
      | _ => Split (0, [], pattern_create ctr_count rest)));
Andreas@68155
   182
Andreas@68155
   183
(* pattern_insert ctr_count terms pat: extend pat so that it covers the
   argument list terms.

   The worker ins carries a flag saying whether new structure may be created.
   Reaching a Wildcard with the flag set turns it into Value (no terms left)
   or into a freshly created and filled pattern for the remaining terms.
   Parts of the tree that are not on the path of terms are rebuilt unchanged
   (ins [] false). *)
fun pattern_insert ctr_count terms pat =
  let
    fun fresh ts = pattern_insert ctr_count ts (pattern_create ctr_count ts)
    fun ins _ false Wildcard = Wildcard
      | ins [] true Wildcard = Value
      | ins ts true Wildcard = fresh ts
      | ins _ _ Value = Value
      | ins ts modify (Split (n, branches, default)) =
          let
            val keep = ins [] false
            val kept_branches = map (apsnd keep) branches
            val kept_default = keep default
          in
            (case ts of
              [] => Split (n, kept_branches, kept_default)
            | t :: rest =>
                let
                  val (head, args) = strip_comb t
                in
                  if not (is_Const head) then
                    (* non-constructor terms continue in the default *)
                    Split (n, kept_branches, ins rest modify default)
                  else
                    (case find_ctr head branches of
                      (SOME (ctr, sub), others) =>
                        Split (n,
                          (ctr, ins (args @ rest) modify sub) :: map (apsnd keep) others,
                          kept_default)
                    | (NONE, _) =>
                        if modify then
                          (* new branch; fix the count if it was still unknown *)
                          Split (if n <= 0 then ctr_count head else n,
                            (head, fresh (args @ rest)) :: kept_branches,
                            kept_default)
                        else Split (n, kept_branches, kept_default)))
                end)
          end
  in
    ins terms true pat
  end;
Andreas@68155
   213
Andreas@68155
   214
(* The pattern covering nothing: pattern_contains yields false on Wildcard. *)
val pattern_empty = Wildcard;
Andreas@68155
   215
Andreas@68155
   216
(* replace_frees lhss rhss typ_list ctxt

   Fixes fresh variables in ctxt and renames the equations apart:
   - two fresh variables for every type in typ_list, returned as pairs,
   - one fresh variable for every argument position (types taken from the
     first left-hand side), returned as case_args,
   - for every equation, fresh replacements for the free variables of its
     left-hand side, substituted in both sides.

   Result: (renamed lhss, renamed rhss, variable pairs, case_args, context).
   lhss must be non-empty. *)
fun replace_frees lhss rhss typ_list ctxt =
  let
    (* fresh variables named after "x", one per type, fixed in the context *)
    fun fresh_frees Ts context =
      let
        val (candidates, context1) = Ctr_Sugar_Util.mk_Frees "x" Ts context
        val (names, types) = split_list (map dest_Free candidates)
        val (names', context2) = Variable.variant_fixes names context1
      in
        (map Free (names' ~~ types), context2)
      end

    fun rename_equation (lhs, rhs) context =
      let
        val old_frees = fold_rev Term.add_frees lhs []
        val (new_frees, context') = fresh_frees (map snd old_frees) context
        val renaming = old_frees ~~ new_frees
        fun rename_free v =
          (case AList.lookup (op =) renaming v of
            SOME t => t
          | NONE => Free v)
        val rename = fold_term Const rename_free Var Bound Abs (op $)
      in
        ((map rename lhs, rename rhs), context')
      end

    val (def_frees, ctxt1) = fresh_frees typ_list ctxt
    val (rhs_frees, ctxt2) = fresh_frees typ_list ctxt1
    val (case_args, ctxt3) = fresh_frees (map fastype_of (hd lhss)) ctxt2
    val (equations, ctxt4) = fold_map rename_equation (lhss ~~ rhss) ctxt3
    val (lhss1, rhss1) = split_list equations
  in
    (lhss1, rhss1, def_frees ~~ rhs_frees, case_args, ctxt4)
  end;
Andreas@68155
   256
Andreas@68155
   257
(* Table transformer that records the names of all type constructors
   occurring in a type (as keys with unit values). *)
fun add_names_in_type T =
  (case T of
    Type (name, Ts) =>
      fold_rev (fn U => fn acc => add_names_in_type U o acc) Ts (Symtab.update (name, ()))
  | _ => I)
Andreas@68155
   261
Andreas@68155
   262
(* Table transformer that records the type constructor names of all types
   attached to constants, variables and abstractions of a term. *)
fun add_names_in_term t =
  (case t of
    Const (_, T) => add_names_in_type T
  | Free (_, T) => add_names_in_type T
  | Var (_, T) => add_names_in_type T
  | Bound _ => I
  | Abs (_, T, body) => add_names_in_type T o add_names_in_term body
  | f $ u => add_names_in_term f o add_names_in_term u)
Andreas@68155
   269
Andreas@68155
   270
(* Compose the add_names_in_term transformers of all given terms. *)
fun add_type_names terms =
  let
    fun extend t acc = add_names_in_term t o acc
  in
    fold extend terms I
  end
Andreas@68155
   272
Andreas@68155
   273
(* The #split theorems of all keys of the table for which
   Ctr_Sugar.ctr_sugar_of finds an entry in ctxt; other keys are skipped. *)
fun get_split_theorems ctxt type_names =
  let
    val sugars = map_filter (Ctr_Sugar.ctr_sugar_of ctxt) (Symtab.keys type_names)
  in
    map #split sugars
  end;
Andreas@68155
   277
Andreas@68155
   278
(* match pat t: first-order matching of t against pat, where the free
   variables of pat are the pattern variables and constants are compared by
   name only.  On success the result maps each pattern variable to the term
   it was matched with and leaves every other term unchanged. *)
fun match pat t =
  (case (pat, t) of
    (Const (c, _), Const (d, _)) => if c = d then SOME I else NONE
  | (Free v, _) => SOME (fn u => if u = Free v then t else u)
  | (p1 $ p2, t1 $ t2) =>
      (case (match p1 t1, match p2 t2) of
        (SOME f, SOME g) => SOME (f o g)
      | _ => NONE)
  | _ => NONE);
Andreas@68155
   286
Andreas@68155
   287
(* Match two lists position by position (see match) and compose the resulting
   substitutions; NONE as soon as one position fails.  The lists must have
   the same length, because they are zipped with ~~. *)
fun match_all patterns terms =
  let
    fun extend (pat, t) acc =
      (case acc of
        NONE => NONE
      | SOME g =>
          (case match pat t of
            SOME f => SOME (f o g)
          | NONE => NONE))
  in
    fold_rev extend (patterns ~~ terms) (SOME I)
  end
Andreas@68155
   295
Andreas@68155
   296
(* Boolean variant of match: does t match pat when the free variables of pat
   match anything and constants are compared by name? *)
fun matches pat t =
  (case (pat, t) of
    (Const (c, _), Const (d, _)) => c = d
  | (Free _, _) => true
  | (p1 $ p2, t1 $ t2) => matches p1 t1 andalso matches p2 t2
  | _ => false);
Andreas@68155
   300
(* matches lifted to lists of equal length. *)
fun matches_all patterns terms =
  forall (fn (pat, t) => matches pat t) (patterns ~~ terms)
Andreas@68155
   301
Andreas@68155
   302
(* terms_to_case_at ctr_count ctxt fun_t default_lhs
     (pos, (lazy_case_arg, rhs_free)) (lhss, rhss, type_name_fun)

   Moves the pattern matching at one position out of the left-hand sides and
   into case expressions on the right-hand sides.

   pos = (argument index, path), where path is a list of (constructor name,
   argument index) steps inside that argument.  The sub-pattern found at pos
   is replaced by lazy_case_arg; rhs_free stands for that sub-pattern in
   equations whose argument does not follow the path.  type_name_fun
   accumulates the type names of the removed sub-patterns.

   Result: (new left-hand sides, new right-hand sides, extended
   type_name_fun). *)
fun terms_to_case_at ctr_count ctxt (fun_t : term) (default_lhs : term list)
    (pos, (lazy_case_arg, rhs_free))
    ((lhss : term list list), (rhss : term list), type_name_fun) =
  let
    (* abort term naming the head constant of t; the head must be a constant *)
    fun abort t =
      mk_abort ("Missing pattern in " ^ fst (dest_Const (head_of t)) ^ ".") t

    (* Step 1 : Eliminate lazy pattern *)
    fun replace_pat_at (arg_idx, path) replacement pats =
      let
        (* apply f to element i and return the updated list with f's second result *)
        fun map_at _ _ [] = raise Empty
          | map_at i f (x :: xs) =
              if i > 0 then
                let val (ys, extracted) = map_at (i - 1) f xs
                in (x :: ys, extracted) end
              else
                let val (y, extracted) = f x
                in (y :: xs, extracted) end
        (* returns (term with the sub-pattern replaced, removed sub-pattern) *)
        fun replace [] new_t old_t = (new_t, old_t)
          | replace ((name, i) :: path') new_t old_t =
              (case strip_comb old_t of
                (head as Const (name', _), args) =>
                  if name = name' then
                    let val (args', extracted) = map_at i (replace path' new_t) args
                    in (list_comb (head, args'), extracted) end
                  else (old_t, rhs_free)
              | _ => (old_t, rhs_free))
        val (front, at_and_back) = chop arg_idx pats
        val (old_pat, back) = (hd at_and_back, tl at_and_back)
        val (new_pat, extracted) = replace path replacement old_pat
      in
        (front @ [new_pat] @ back, extracted)
      end
    val (lhss1, lazy_pats) = split_list (map (replace_pat_at pos lazy_case_arg) lhss)

    (* Step 2 : Split patterns *)
    fun split equs =
      let
        (* most general common instance of two patterns, if there is one *)
        fun merge_pattern (Const (c, T), Const (d, _)) =
              if c = d then SOME (Const (c, T)) else NONE
          | merge_pattern (t, Free _) = SOME t
          | merge_pattern (Free _, t) = SOME t
          | merge_pattern (f1 $ u1, f2 $ u2) =
              (case (merge_pattern (f1, f2), merge_pattern (u1, u2)) of
                (SOME f, SOME u) => SOME (f $ u)
              | _ => NONE)
          | merge_pattern _ = NONE
        fun merge_patterns [] [] = SOME []
          | merge_patterns (x :: xs) (y :: ys) =
              (case (merge_pattern (x, y), merge_patterns xs ys) of
                (SOME z, SOME zs) => SOME (z :: zs)
              | _ => NONE)
          | merge_patterns _ _ = raise Match
        (* entries pair a left-hand side with the case patterns already seen for it *)
        fun merge_insert ((lhs1, case_pat), _) [] =
              [(lhs1, pattern_insert ctr_count [case_pat] pattern_empty)]
          | merge_insert (equ as ((lhs1, case_pat), _)) ((lhs2, pat) :: rest) =
              let
                val rest' = merge_insert equ rest
                val (first_equ_needed, new_lhs) =
                  (case merge_patterns lhs1 lhs2 of
                    SOME merged => (not (pattern_contains [case_pat] pat), merged)
                  | NONE => (false, lhs2))
                val second_equ_needed =
                  not (matches_all lhs1 lhs2 andalso first_equ_needed)
                val first_equ =
                  if first_equ_needed
                  then [(new_lhs, pattern_insert ctr_count [case_pat] pat)]
                  else []
                val second_equ = if second_equ_needed then [(lhs2, pat)] else []
              in
                first_equ @ second_equ @ rest'
              end
      in
        map fst (fold merge_insert equs []) @ [default_lhs]
      end
    val lhss2 = split ((lhss1 ~~ lazy_pats) ~~ rhss)

    (* Step 3 : Remove redundant patterns *)
    fun remove_redundant_lhs candidates =
      let
        fun visit lhs seen =
          if pattern_contains lhs seen then (NONE, seen)
          else (SOME lhs, pattern_insert ctr_count lhs seen)
      in
        map_filter I (fst (fold_map visit candidates pattern_empty))
      end
    fun remove_redundant_rhs clauses =
      let
        fun visit (clause as (case_pat, _)) seen =
          if pattern_contains [case_pat] seen then (NONE, seen)
          else (SOME clause, pattern_insert ctr_count [case_pat] seen)
      in
        map_filter I (fst (fold_map visit clauses pattern_empty))
      end
    val lhss3 = remove_redundant_lhs lhss2

    (* Step 4 : Compute right hand side *)
    fun subs_fun f =
      fold_term Const (f o Free) Var Bound Abs (fn (x, y) => f x $ f y)
    fun find_rhss lhs =
      let
        fun instantiate (lhs1, (case_pat, rhs)) =
          (case match_all lhs1 lhs of
            SOME subst => SOME (case_pat, subs_fun subst rhs)
          | NONE => NONE)
        val matching = map_filter instantiate (lhss1 ~~ (lazy_pats ~~ rhss))
        (* final clause: abort with the function applied to this left-hand side *)
        val fallback = (lazy_case_arg, abort (list_comb (fun_t, lhs)))
      in
        remove_redundant_rhs (matching @ [fallback])
      end

    (* Step 5 : make_case of right hand side *)
    fun make_case context case_arg cases =
      (case cases of
        (* a single variable pattern needs no case expression *)
        [(Free v, rhs)] => subs_fun (fn u => if u = Free v then case_arg else u) rhs
      | _ =>
          Case_Translation.make_case context Case_Translation.Warning Name.context
            case_arg cases)
    val type_name_fun' = add_type_names lazy_pats o type_name_fun
    val rhss3 = map (make_case ctxt lazy_case_arg o find_rhss) lhss3
  in
    (lhss3, rhss3, type_name_fun')
  end;
Andreas@68155
   442
Andreas@68155
   443
(* Rename the equations apart, apply terms_to_case_at for every position in
   poss (last position first), and return the resulting equations as
   Trueprop terms together with the split theorems of all type names that
   were collected and the extended context. *)
fun terms_to_case ctxt ctr_count (head : term) (lhss : term list list)
    (rhss : term list) (typ_list : typ list) (poss : (int * (string * int) list) list) =
  let
    val (lhss1, rhss1, def_frees, case_args, ctxt1) = replace_frees lhss rhss typ_list ctxt
    val convert_at = terms_to_case_at ctr_count ctxt1 head case_args
    val (lhss2, rhss2, type_name_fun) =
      fold_rev convert_at (poss ~~ def_frees) (lhss1, rhss1, I)
    fun make_eq_term (args, rhs) =
      HOLogic.mk_Trueprop (HOLogic.mk_eq (list_comb (head, args), rhs))
    val eq_terms = map make_eq_term (lhss2 ~~ rhss2)
    val split_thms = get_split_theorems ctxt1 (type_name_fun Symtab.empty)
  in
    (eq_terms, split_thms, ctxt1)
  end;
Andreas@68155
   458
Andreas@68155
   459
(* build_case_t replace_ctr ctr_count head lhss rhss ctxt

   Checks that there is at least one equation, as many right-hand sides as
   left-hand sides, and the same number of arguments everywhere (Fail
   otherwise).  Then collects, per argument column, the positions of the
   constructors selected by replace_ctr.  Without any such position the
   result is ([], [], ctxt); otherwise the work is passed on to
   terms_to_case. *)
fun build_case_t replace_ctr ctr_count head lhss rhss ctxt =
  let
    val num_eqs = length lhss
    val _ =
      if length rhss <> num_eqs orelse num_eqs <= 0 then
        raise Fail
          ("expected same number of left-hand sides as right-hand sides\n"
            ^ "and at least one equation")
      else ()
    val arity = length (hd lhss)
    val _ =
      if exists (fn args => length args <> arity) lhss
      then raise Fail "expected equal number of arguments"
      else ()

    (* all positions found in one argument column, tagged with the column index *)
    fun column_positions (idx, column) =
      (case map_filter (term_to_coordinates replace_ctr) column of
        [] => NONE
      | tco :: tcos =>
          SOME (map (fn (T, path) => (T, (idx, path)))
            (coordinates_to_list (fold term_coordinate_merge tcos tco))))
    val columns = Ctr_Sugar_Util.transpose lhss
    val (typ_list, poss) =
      split_list (flat (map_filter I (map_index column_positions columns)))
  in
    if null poss then ([], [], ctxt)
    else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss
  end;
Andreas@68155
   484
Andreas@68155
   485
(* Tactic for the generated case equations: repeatedly split with the given
   split rules (introducing conjunctions, quantifiers and implications and
   substituting the resulting hypotheses), unfold missing_pattern_match_def
   and then defs, and finally solve every remaining subgoal by refl or one of
   intros. *)
fun tac ctxt {splits, intros, defs} =
  let
    val intro_step = resolve_tac ctxt [@{thm conjI}, @{thm allI}]
    val subst_step = resolve_tac ctxt [@{thm impI}] THEN' hyp_subst_tac_thin true ctxt
    val split_and_subst =
      split_tac ctxt splits THEN' REPEAT_ALL_NEW (intro_step ORELSE' subst_step)
    (* splitting is optional: carry on unchanged if it does not apply *)
    val split_all = REPEAT_ALL_NEW split_and_subst ORELSE' K all_tac
    val unfold_abort = K (Local_Defs.unfold_tac ctxt [@{thm missing_pattern_match_def}])
    val unfold_defs = K (Local_Defs.unfold_tac ctxt defs)
    val close = SOLVED' (resolve_tac ctxt (@{thm refl} :: intros))
  in
    (split_all THEN' unfold_abort THEN' unfold_defs) THEN_ALL_NEW close
  end;
Andreas@68155
   499
Andreas@68155
   500
(* to_case ctxt replace_ctr ctr_count ths

   Converts the equations ths of one function into equations in which the
   pattern matching on the constructors selected by replace_ctr is expressed
   by case expressions.  The right-hand sides are hidden behind local
   definitions while the new equations are built, and every new equation is
   proved from ths with tac.  Returns NONE for an empty list of theorems or
   when build_case_t produces no equations. *)
fun to_case _ _ _ [] = NONE
  | to_case ctxt replace_ctr ctr_count ths =
      let
        fun strip_eq th = HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.prop_of th))

        (* instantiate the heads of all equations to the head of the first
           one, then import the theorems into the context *)
        fun import [] context = ([], context)
          | import (th :: rest) context =
              let
                fun head_cterm eq_thm =
                  Thm.cterm_of context (Logic.mk_term (head_of (fst (strip_eq eq_thm))))
                val target = head_cterm th
                val heads = map head_cterm rest
                val rest' =
                  map (fn (eq_thm, head_ct) =>
                    Thm.instantiate (Thm.match (head_ct, target)) eq_thm) (rest ~~ heads)
              in
                apfst snd (Variable.import true (th :: rest') context)
              end

        val (iths, ctxt') = import ths ctxt
        val head = head_of (fst (strip_eq (hd iths)))
        (* pairs of (argument list, right-hand side) *)
        val eqs = map (fn th => apfst (snd o strip_comb) (strip_eq th)) iths

        (* replace a right-hand side by a locally defined function applied to
           the free variables of the arguments *)
        fun hide_rhs ((args, rhs), name) lthy =
          let
            val frees = fold Term.add_frees args []
            val abs_rhs = fold absfree frees rhs
            val spec = ((Binding.name name, NoSyn), (Binding.empty_atts, abs_rhs))
          in
            (case Local_Defs.define [spec] lthy of
              ([(f, (_, def))], lthy') =>
                ((list_comb (f, map Free (rev frees)), def), lthy')
            | _ => raise Match)
          end

        val rhs_names = Name.invent (Variable.names_of ctxt') "rhs" (length eqs)
        val (hidden, ctxt2) = fold_map hide_rhs (eqs ~~ rhs_names) ctxt'
        val (def_ts, def_thms) = split_list hidden
        val (ts, split_thms, ctxt3) =
          build_case_t replace_ctr ctr_count head (map fst eqs) def_ts ctxt2
        fun prove_case t =
          Goal.prove ctxt3 [] [] t (fn {context = goal_ctxt, ...} =>
            tac goal_ctxt {splits = split_thms, intros = ths, defs = def_thms} 1)
      in
        if null ts then NONE
        else
          SOME (map (Goal.norm_result ctxt)
            (Proof_Context.export ctxt3 ctxt (map prove_case ts)))
      end;
Andreas@68155
   548
Andreas@68155
   549
end