src/HOL/Tools/datatype_case.ML
author paulson
Tue May 22 17:56:06 2007 +0200 (2007-05-22)
changeset 23075 69e30a7e8880
parent 22779 9ac0ca736969
child 24349 0dd8782fb02d
permissions -rw-r--r--
Some hacks for SPASS format
berghofe@22779
     1
(*  Title:      HOL/Tools/datatype_case.ML
berghofe@22779
     2
    ID:         $Id$
berghofe@22779
     3
    Author:     Konrad Slind, Cambridge University Computer Laboratory
berghofe@22779
     4
                Stefan Berghofer, TU Muenchen
berghofe@22779
     5
berghofe@22779
     6
Nested case expressions on datatypes.
berghofe@22779
     7
*)
berghofe@22779
     8
berghofe@22779
     9
signature DATATYPE_CASE =
sig
  (* Compile (pattern, right-hand side) clauses over a scrutinee into a
     nested case term.  Arguments: datatype lookup table, context, a flag
     (true: redundant clauses are an error, false: only a warning), names
     already in use, the scrutinee, and the clauses.  Result: the case term
     and the completed list of patterns, each tagged with
     (clause index, supplied-by-user flag). *)
  val make_case:
    (string -> DatatypeAux.datatype_info option) ->
    Proof.context -> bool -> string list -> term -> (term * term) list ->
    term * (term * (int * bool)) list

  (* Destruct one level of a case term into scrutinee and clauses.  The flag
     chooses "dummy_pattern" (true) or a fresh variable (false) as the
     pattern of a default clause. *)
  val dest_case:
    (string -> DatatypeAux.datatype_info option) -> bool ->
    string list -> term -> (term * (term * term) list) option

  (* Like dest_case, but also flattens nested case terms into nested
     patterns. *)
  val strip_case:
    (string -> DatatypeAux.datatype_info option) -> bool ->
    term -> (term * (term * term) list) option

  (* Parse translation for case syntax. *)
  val case_tr:
    (theory -> string -> DatatypeAux.datatype_info option) ->
    Proof.context -> term list -> term

  (* Print translation for applications of the case constant named by the
     string argument. *)
  val case_tr':
    (theory -> string -> DatatypeAux.datatype_info option) ->
    string -> Proof.context -> term list -> term
end;
berghofe@22779
    23
berghofe@22779
    24
structure DatatypeCase : DATATYPE_CASE =
berghofe@22779
    25
struct
berghofe@22779
    26
berghofe@22779
    27
(* Internal failure of case translation: an error message plus the index of
   the offending clause.  A negative index means no single clause is to
   blame (gen_make_case then omits the "In clause" part of its report).
   gen_dest_case also raises and catches it internally to signal that a
   term cannot be destructed. *)
exception CASE_ERROR of string * int;
berghofe@22779
    28
berghofe@22779
    29
(* Match the type pattern pat against the type ob, starting from the empty
   type environment.  On failure Sign.typ_match raises; callers such as
   fresh_constr catch Type.TYPE_MATCH. *)
fun match_type thy pat ob = Vartab.empty |> Sign.typ_match thy (pat, ob);
berghofe@22779
    30
berghofe@22779
    31
(*---------------------------------------------------------------------------
berghofe@22779
    32
 * Get information about datatypes
berghofe@22779
    33
 *---------------------------------------------------------------------------*)
berghofe@22779
    34
berghofe@22779
    35
(* Look up s in tab.  If an entry exists, return the name of the datatype's
   case combinator and its constructors as constants.  The constructor
   types have schematic type variables (Logic.varifyT). *)
fun ty_info (tab : string -> DatatypeAux.datatype_info option) s =
  Option.map (fn {descr, case_name, index, sorts, ...} =>
    let
      val (_, (tname, dts, constrs)) = nth descr index;
      val mk_ty = DatatypeAux.typ_of_dtyp descr sorts;
      val dt_type = Type (tname, map mk_ty dts);
      fun mk_constr (cname, arg_dts) =
        Const (cname, Logic.varifyT (map mk_ty arg_dts ---> dt_type))
    in
      {case_name = case_name, constructors = map mk_constr constrs}
    end) (tab s);
berghofe@22779
    48
berghofe@22779
    49
berghofe@22779
    50
(*---------------------------------------------------------------------------
berghofe@22779
    51
 * Each pattern carries with it a tag (i,b) where
berghofe@22779
    52
 * i is the clause it came from and
berghofe@22779
    53
 * b=true indicates that clause was given by the user
berghofe@22779
    54
 * (or is an instantiation of a user supplied pattern)
berghofe@22779
    55
 * b=false --> i = ~1
berghofe@22779
    56
 *---------------------------------------------------------------------------*)
berghofe@22779
    57
berghofe@22779
    58
(* Apply f to the term component of a (term, tag) pair; the tag is kept. *)
fun pattern_map f = apfst f;
berghofe@22779
    59
berghofe@22779
    60
(* Apply the free-variable substitution theta to the term component of a
   (term, tag) pair, leaving the tag unchanged. *)
fun pattern_subst theta (tm, tag) = (subst_free theta tm, tag);
berghofe@22779
    61
berghofe@22779
    62
(* Clause index stored in the tag of a (pattern, (index, flag)) pair. *)
fun row_of_pat (_, (i, _)) = i;
berghofe@22779
    63
berghofe@22779
    64
(* Add to used the names of all free variables occurring in a row: first its
   right-hand side, then its pattern columns, then its prefix. *)
fun add_row_used ((prfx, pats), (tm, _)) used =
  let
    val with_rhs = add_term_free_names (tm, used);
    val with_pats = foldl add_term_free_names with_rhs pats
  in
    foldl add_term_free_names with_pats prfx
  end;
berghofe@22779
    67
berghofe@22779
    68
(* Try to preserve names given by the user.  Position-wise, keep the name
   collected so far; if there is none yet ("") and the new argument is a
   free variable, take that variable's name instead. *)
fun default_names names ts =
  map2 (fn name => fn t =>
    (case (name, t) of
      ("", Free (name', _)) => name'
    | _ => name)) names ts;
berghofe@22779
    71
berghofe@22779
    72
(* Peel off nested "_constrain" wrappers.  Return the innermost term and the
   constraint types, outermost first. *)
fun strip_constraints t =
  let
    fun strip (Const ("_constrain", _) $ u $ uT) acc = strip u (uT :: acc)
      | strip u acc = (u, rev acc)
  in
    strip t []
  end;
berghofe@22779
    75
berghofe@22779
    76
(* Wrap t in a syntax-level "_constrain" whose type is a "fun" type.  The
   argument type is tT and the result is left as "dummy". *)
fun mk_fun_constrain tT t =
  let
    val fun_type = Syntax.free "fun" $ tT $ Syntax.free "dummy"
  in
    Syntax.const "_constrain" $ t $ fun_type
  end;
berghofe@22779
    78
berghofe@22779
    79
berghofe@22779
    80
(*---------------------------------------------------------------------------
berghofe@22779
    81
 * Produce an instance of a constructor, plus genvars for its arguments.
berghofe@22779
    82
 *---------------------------------------------------------------------------*)
berghofe@22779
    83
(* Instantiate the constructor c so that its result type matches the column
   type colty.  Also create free variables for its arguments, with names
   distinct from used.  Returns (instantiated constructor, argument
   variables).  A failed match becomes CASE_ERROR ("type mismatch", ~1). *)
fun fresh_constr ty_match ty_inst colty used c =
  let
    val cT = snd (dest_Const c);
    val argTs = binder_types cT;
    val fresh_names = Name.variant_list used
      (DatatypeProp.make_tnames (map Logic.unvarifyT argTs));
    val theta = ty_match (body_type cT) colty
      handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
    val inst = ty_inst theta
  in
    (inst c, map (inst o Free) (fresh_names ~~ argTs))
  end;
berghofe@22779
    96
berghofe@22779
    97
berghofe@22779
    98
(*---------------------------------------------------------------------------
berghofe@22779
    99
 * Goes through a list of rows and picks out the ones beginning with a
berghofe@22779
   100
 * pattern with constructor = name.
berghofe@22779
   101
 *---------------------------------------------------------------------------*)
berghofe@22779
   102
(* Split rows into two lists, preserving their order:
   - rows whose first pattern applies the constructor name; in these the
     constructor's arguments (constraints stripped) replace that pattern;
   - all other rows.
   Also collect, per argument position, the first user-given variable name
   ("" if none) and the type constraints found.  Raises CASE_ERROR for a
   wrong number of arguments or a pattern headed by a non-constant. *)
fun mk_group (name, T) rows =
  let
    val arity = length (binder_types T);
    fun step (row as ((prfx, p :: rst), rhs as (_, (i, _))))
        ((in_group, not_in_group), (names, cnstrts)) =
      (case strip_comb p of
        (Const (name', _), args) =>
          if name <> name' then
            ((in_group, row :: not_in_group), (names, cnstrts))
          else if length args <> arity then
            raise CASE_ERROR
              ("Wrong number of arguments for constructor " ^ name, i)
          else
            let
              val (args', cnstrts') = split_list (map strip_constraints args)
            in
              ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
               (default_names names args', map2 append cnstrts cnstrts'))
            end
      | _ => raise CASE_ERROR ("Not a constructor pattern", i));
    val start = (([], []), (replicate arity "", replicate arity []))
  in
    fold step rows start |>> pairself rev
  end;
berghofe@22779
   120
berghofe@22779
   121
(*---------------------------------------------------------------------------
berghofe@22779
   122
 * Partition the rows. Not efficient: we should use hashing.
berghofe@22779
   123
 *---------------------------------------------------------------------------*)
berghofe@22779
   124
(* Build one subproblem per constructor, in the order of constructors.  Each
   subproblem is a record with: the instantiated constructor, fresh formals
   for its arguments, the preferred binder names, the collected constraints,
   and the group of rows for that constructor.  A constructor no row
   mentions gets a single default row with right-hand side HOL.undefined
   and tag (~1, false).  Rows left over after all constructors are
   processed are an error. *)
fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
  | partition ty_match ty_inst type_of used constructors colty res_ty
        (rows as (((prfx, _ :: rstp), _) :: _)) =
      let
        (* Default row: constructor formals followed by fresh variables for
           the remaining columns. *)
        fun default_group used' gvars =
          let
            val Ts = map type_of rstp;
            val xs = Name.variant_list
              (foldl add_term_free_names used' gvars)
              (replicate (length rstp) "x")
          in
            [((prfx, gvars @ map Free (xs ~~ Ts)),
              (Const ("HOL.undefined", res_ty), (~1, false)))]
          end;

        (* Subproblem for constructor c, plus the rows it does not take. *)
        fun subproblem c remaining =
          let
            val ((in_group, not_in_group), (names, cnstrts)) =
              mk_group (dest_Const c) remaining;
            val used' = fold add_row_used in_group used;
            val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
            val group =
              if null in_group then default_group used' gvars else in_group
          in
            ({constructor = c',
              new_formals = gvars,
              names = names,
              constraints = cnstrts,
              group = group}, not_in_group)
          end;

        fun part [] [] acc = rev acc
          | part [] ((_, (_, (i, _))) :: _) _ =
              raise CASE_ERROR ("Not a constructor pattern", i)
          | part (c :: crst) remaining acc =
              let val (sub, rest) = subproblem c remaining
              in part crst rest (sub :: acc) end
      in
        part constructors rows []
      end;
berghofe@22779
   161
berghofe@22779
   162
(*---------------------------------------------------------------------------
berghofe@22779
   163
 * Misc. routines used in mk_case
berghofe@22779
   164
 *---------------------------------------------------------------------------*)
berghofe@22779
   165
berghofe@22779
   166
(* In every (prefix, tag, patterns) row of l, fold the first n patterns
   (n = number of argument types of c) into one application of c'. *)
fun mk_pat ((c, c'), l) =
  let
    val arity = length (binder_types (fastype_of c))
  in
    l |> map (fn (prfx, tag, plist) =>
      let val (args, rest) = chop arity plist
      in (prfx, tag, list_comb (c', args) :: rest) end)
  end;
berghofe@22779
   173
berghofe@22779
   174
(* Move the first pattern column onto the prefix. *)
fun v_to_prfx (prfx, pats) =
  (case pats of
    v :: rest => (v :: prfx, rest)
  | [] => raise CASE_ERROR ("mk_case: v_to_prfx", ~1));
berghofe@22779
   176
berghofe@22779
   177
(* Inverse of v_to_prfx: move the head of the prefix back onto the
   patterns. *)
fun v_to_pats (prfx, tag, pats) =
  (case prfx of
    v :: rest => (rest, tag, v :: pats)
  | [] => raise CASE_ERROR ("mk_case: v_to_pats", ~1));
berghofe@22779
   179
berghofe@22779
   180
berghofe@22779
   181
(*----------------------------------------------------------------------------
berghofe@22779
   182
 * Translation of pattern terms into nested case expressions.
berghofe@22779
   183
 *
berghofe@22779
   184
 * This performs the translation and also builds the full set of patterns.
berghofe@22779
   185
 * Thus it supports the construction of induction theorems even when an
berghofe@22779
   186
 * incomplete set of patterns is given.
berghofe@22779
   187
 *---------------------------------------------------------------------------*)
berghofe@22779
   188
berghofe@22779
   189
fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
  let
    (* Fallback binder name for case functions whose pattern variable has
       no user-given name. *)
    val name = Name.variant used "a";

    (* If the first pattern p of a row is a variable, replace the row by one
       row per constructor.  Each new row substitutes a fresh constructor
       application for p in that column and in the right-hand side.  Other
       rows are returned unchanged. *)
    fun expand constructors used ty ((_, []), _) =
          raise CASE_ERROR ("mk_case: expand_var_row", ~1)
      | expand constructors used ty (row as ((prfx, p :: rst), rhs)) =
          if is_Free p then
            let
              val used' = add_row_used row used;
              fun expnd c =
                let val capp =
                  list_comb (fresh_constr ty_match ty_inst ty used' c)
                in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs)
                end
            in map expnd constructors end
          else [row]

    (* Compile the row matrix.  path holds the terms the remaining pattern
       columns are matched against.  Returns (prefix, tag, patterns) triples
       for the rows that end up at leaves, together with the case term. *)
    fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1)
      | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} =  (* Done *)
          ([(prfx, tag, [])], tm)
      (* The first row's only remaining pattern is a variable, so the later
         rows cannot be reached. *)
      | mk {path, rows = (row as ((_, [Free _]), _)) :: _ :: _} =
          mk {path = path, rows = [row]}
      | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} =
          let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows
          in case Option.map (apfst head_of)
            (find_first (not o is_Free o fst) col0) of
              NONE =>
                (* Column of variables only: bind each to u and continue
                   with the remaining columns. *)
                let
                  val rows' = map (fn ((v, _), row) => row ||>
                    pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
                  val (pref_patl, tm) = mk {path = rstp, rows = rows'}
                in (map v_to_pats pref_patl, tm) end
            | SOME (Const (cname, cT), i) => (case ty_info tab cname of
                NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
              | SOME {case_name, constructors} =>
                let
                  val pty = body_type cT;
                  val used' = foldl add_term_free_names used rstp;
                  val nrows = maps (expand constructors used' pty) rows;
                  val subproblems = partition ty_match ty_inst type_of used'
                    constructors pty range_ty nrows;
                  val constructors' = map #constructor subproblems
                  val news = map (fn {new_formals, group, ...} =>
                    {path = new_formals @ rstp, rows = group}) subproblems;
                  val (pat_rect, dtrees) = split_list (map mk news);
                  (* One case function per constructor: abstract the subtree
                     over the constructor's formals, applying any recorded
                     type constraints to each abstraction. *)
                  val case_functions = map2
                    (fn {new_formals, names, constraints, ...} =>
                       fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
                         Abs (if s = "" then name else s, T,
                           abstract_over (x, t)) |>
                         fold mk_fun_constrain cnstrts)
                           (new_formals ~~ names ~~ constraints))
                    subproblems dtrees;
                  val types = map type_of (case_functions @ [u]);
                  val case_const = Const (case_name, types ---> range_ty)
                  val tree = list_comb (case_const, case_functions @ [u])
                  val pat_rect1 = maps mk_pat
                    (constructors ~~ constructors' ~~ pat_rect)
                in (pat_rect1, tree)
                end)
            | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
                ProofContext.string_of_term ctxt t, i)
          end
      | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
  in mk
  end;
berghofe@22779
   255
berghofe@22779
   256
(* Report a user-level error in a case expression; never returns. *)
fun case_error msg = error (String.concat ["Error in case expression:\n", msg]);
berghofe@22779
   257
berghofe@22779
   258
(* Repeated variable occurrences in a pattern are not allowed.  Walk the
   atoms of pat, collecting free variables.  Fail with case_error on the
   first one seen twice; otherwise return the collected variables. *)
fun no_repeat_vars ctxt pat =
  let
    fun check (x as Free (s, _)) seen =
          if member op aconv seen x then
            case_error (quote s ^ " occurs repeatedly in the pattern " ^
              quote (ProofContext.string_of_term ctxt pat))
          else x :: seen
      | check _ seen = seen
  in
    fold_aterms check pat []
  end;
berghofe@22779
   266
berghofe@22779
   267
(* Compile clauses (pattern, rhs) over the scrutinee x into a case term.
   ty_match, ty_inst and type_of parameterize type handling (real types in
   make_case, dummy types in make_case_untyped).  Returns the case term and
   the completed patterns, sorted by clause index.  User clauses that do not
   contribute a pattern are reported as redundant: an error if err is true,
   otherwise a warning. *)
fun gen_make_case ty_match ty_inst type_of tab ctxt err used x clauses =
  let
    fun string_of_clause (pat, rhs) = ProofContext.string_of_term ctxt
      (Syntax.const "_case1" $ pat $ rhs);
    (* Run only for the check: no_repeat_vars fails on a repeated variable;
       its result is not needed. *)
    val _ = List.app (ignore o no_repeat_vars ctxt o fst) clauses;
    val rows = map_index (fn (i, (pat, rhs)) =>
      (([], [pat]), (rhs, (i, true)))) clauses;
    val rangeT = (case distinct op = (map (type_of o snd) clauses) of
        [] => case_error "no clauses given"
      | [T] => T
      | _ => case_error "all cases must have the same result type");
    val used' = fold add_row_used rows used;
    val (patts, case_tm) = mk_case tab ctxt ty_match ty_inst type_of
        used' rangeT {path = [x], rows = rows}
      handle CASE_ERROR (msg, i) => case_error (msg ^
        (if i < 0 then ""
         else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
    val patts1 = map
      (fn (_, tag, [pat]) => (pat, tag)
        | _ => case_error "error in pattern-match translation") patts;
    val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1
    val finals = map row_of_pat patts2
    val originals = map (row_of_pat o #2) rows
    (* A user clause whose index no longer appears is covered by earlier ones. *)
    val _ = case originals \\ finals of
        [] => ()
      | is => (if err then case_error else warning)
          ("The following clauses are redundant (covered by preceding clauses):\n" ^
           space_implode "\n" (map (string_of_clause o nth clauses) is));
  in
    (case_tm, patts2)
  end;
berghofe@22779
   298
berghofe@22779
   299
(* Typed instance of gen_make_case: types are matched and instantiated in
   the theory of ctxt, and fastype_of computes term types. *)
fun make_case tab ctxt =
  let
    val thy = ProofContext.theory_of ctxt
  in
    gen_make_case (match_type thy) Envir.subst_TVars fastype_of tab ctxt
  end;
berghofe@22779
   301
(* Untyped instance of gen_make_case for parse translations: type matching
   always succeeds with the empty environment, instantiation erases all types
   to dummyT, and every term is given type dummyT. *)
val make_case_untyped = gen_make_case
  (fn _ => fn _ => Vartab.empty)
  (fn _ => Term.map_types (K dummyT))
  (fn _ => dummyT);
berghofe@22779
   303
berghofe@22779
   304
berghofe@22779
   305
(* parse translation *)
berghofe@22779
   306
berghofe@22779
   307
(* Parse translation for case syntax.  t is the scrutinee and u the
   "_case2"-chained list of "_case1" clauses.  Patterns are prepared
   (dummy patterns become fresh variables, constant names are internalized),
   and type constraints on a pattern are moved onto the scrutinee.  The
   result is compiled with make_case_untyped, treating redundant clauses as
   errors. *)
fun case_tr tab_of ctxt [t, u] =
    let
      val thy = ProofContext.theory_of ctxt;
      (* replace occurrences of dummy_pattern by distinct variables *)
      (* internalize constant names                                 *)
      fun prep_pat ((c as Const ("_constrain", _)) $ t $ tT) used =
            let val (t', used') = prep_pat t used
            in (c $ t' $ tT, used') end
        | prep_pat (Const ("dummy_pattern", T)) used =
            let val x = Name.variant used "x"
            in (Free (x, T), x :: used) end
        | prep_pat (Const (s, T)) used =
            (case try (unprefix Syntax.constN) s of
               SOME c => (Const (c, T), used)
             | NONE => (Const (Sign.intern_const thy s, T), used))
        | prep_pat (v as Free (s, T)) used =
            (* a free name that denotes a declared constant is a constant *)
            let val s' = Sign.intern_const thy s
            in
              if Sign.declared_const thy s' then
                (Const (s', T), used)
              else (v, used)
            end
        | prep_pat (t $ u) used =
            let
              val (t', used') = prep_pat t used;
              val (u', used'') = prep_pat u used'
            in
              (t' $ u', used'')
            end
        | prep_pat t _ = case_error ("Bad pattern: " ^
            ProofContext.string_of_term ctxt t);
      (* Fresh names avoid every free name of the whole clause, including
         its right-hand side. *)
      fun dest_case1 (t as Const ("_case1", _) $ l $ r) =
            let val (l', cnstrts) = strip_constraints l
            in ((fst (prep_pat l' (add_term_free_names (t, []))), r), cnstrts)
            end
        | dest_case1 _ = case_error "dest_case1";
      fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
        | dest_case2 t = [t];
      val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
      val (case_tm, _) = make_case_untyped (tab_of thy) ctxt true []
        (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT)
           (flat cnstrts) t) cases;
    in case_tm end
  | case_tr _ _ _ = case_error "case_tr";
berghofe@22779
   351
berghofe@22779
   352
berghofe@22779
   353
(*---------------------------------------------------------------------------
berghofe@22779
   354
 * Pretty printing of nested case expressions
berghofe@22779
   355
 *---------------------------------------------------------------------------*)
berghofe@22779
   356
berghofe@22779
   357
(* destruct one level of pattern matching *)
berghofe@22779
   358
berghofe@22779
   359
(* Destruct one level of a case term t = case_c f1 ... fn x.
   Each fi is turned into a clause (constructor applied to fresh variables,
   body).  Clauses are then merged: repeated identical bodies collapse into
   a default clause whose pattern is "dummy_pattern" (d = true) or a fresh
   variable (d = false), and HOL.undefined bodies are dropped.  Returns NONE
   when t does not have the expected shape.  name_of extracts a constant
   name from a term; type_of computes term types. *)
fun gen_dest_case name_of type_of tab d used t =
  let
    (* Turn the first i bound variables of u into free variables named apart
       from used and the remaining body; return them with the instantiated
       body.  Raises CASE_ERROR if u has fewer than i abstractions. *)
    fun strip_abs i u =
      let
        val bounds = strip_abs_vars u;
        val _ = if length bounds < i then raise CASE_ERROR ("", 0) else ();
        val (outer, inner) = chop i bounds;
        val body = list_abs (inner, strip_abs_body u);
        val frees = map Free (Name.variant_list (add_term_names (body, used))
          (map fst outer) ~~ map snd outer)
      in (frees, subst_bounds (rev frees, body)) end;

    (* Does the body of u refer to one of its first i bound variables?
       Also true when u has fewer than i abstractions. *)
    fun is_dependent i u =
      let val k = length (strip_abs_vars u) - i
      in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body u)) end;

    (* Group the constructors of non-dependent clauses by identical body. *)
    fun count_cases (_, _, true) = I
      | count_cases (c, (_, body), false) =
          AList.map_default op aconv (body, []) (cons c);

    val is_undefined = name_of #> equal (SOME "HOL.undefined");

    fun mk_clause (c, (xs, body), _) = (list_comb (c, xs), body);

    (* Clauses for case constant cname applied to functions fs and
       scrutinee x. *)
    fun dest_clauses cname fs x =
      (case ty_info tab cname of
        NONE => NONE
      | SOME {constructors, case_name = _} =>
          if length fs <> length constructors then NONE
          else
            (let
              val cases = map (fn (Const (s, U), f) =>
                let
                  val k = length (binder_types U);
                  val p as (xs, _) = strip_abs k f
                in
                  (Const (s, map type_of xs ---> type_of x), p, is_dependent k f)
                end) (constructors ~~ fs);
              (* largest groups of identical bodies first *)
              val groups = sort (int_ord o swap o pairself (length o snd))
                (fold_rev count_cases cases []);
              (* NOTE(review): R is the type of the whole case term t, while
                 the dummy is used as a pattern for x; confirm whether
                 type_of x was intended. *)
              val R = type_of t;
              val dummy = if d then Const ("dummy_pattern", R)
                else Free (Name.variant used "x", R);
              fun with_default body = [(dummy, ([], body), false)];
              val selected =
                (case find_first (is_undefined o fst) groups of
                  SOME (_, cs) =>
                    if length cs = length constructors then [hd cases]
                    else filter_out (fn (_, (_, body), _) => is_undefined body) cases
                | NONE =>
                    (case groups of
                      [] => cases
                    | (default, cs) :: _ =>
                        if length cs = 1 then cases
                        else if length cs = length constructors then
                          hd cases :: with_default default
                        else
                          filter_out (fn (c, _, _) => member op aconv cs c) cases @
                            with_default default))
            in
              SOME (x, map mk_clause selected)
            end handle CASE_ERROR _ => NONE))
  in
    case apfst name_of (strip_comb t) of
      (SOME cname, ts as _ :: _) =>
        let val (fs, x) = split_last ts in dest_clauses cname fs x end
    | _ => NONE
  end;
berghofe@22779
   419
berghofe@22779
   420
(* gen_dest_case on internal terms: constant names as they are, types
   computed with fastype_of. *)
fun dest_case tab = gen_dest_case (try (fst o dest_Const)) fastype_of tab;
berghofe@22779
   421
(* gen_dest_case on syntax-level terms: constant names carry the
   Syntax.constN prefix (other constants are rejected), and all types are
   dummyT. *)
fun dest_case' tab =
  gen_dest_case (try (unprefix Syntax.constN o fst o dest_Const)) (K dummyT) tab;
berghofe@22779
   423
berghofe@22779
   424
berghofe@22779
   425
(* destruct nested patterns *)
berghofe@22779
   426
berghofe@22779
   427
(* Flatten nested case terms into nested patterns.  If rhs is itself a case
   on a variable exp of pat, and exp occurs in none of the inner bodies,
   substitute each inner pattern for exp in pat and recurse.  Otherwise keep
   (pat, rhs) as is. *)
fun strip_case' dest (pat, rhs) =
  let
    fun expandable exp clauses =
      member op aconv (term_frees pat) exp andalso
      not (exists (fn (_, rhs') =>
        member op aconv (term_frees rhs') exp) clauses)
  in
    case dest (add_term_free_names (pat, [])) rhs of
      SOME (exp as Free _, clauses) =>
        if expandable exp clauses then
          maps (fn (pat', rhs') =>
            strip_case' dest (subst_free [(exp, pat')] pat, rhs')) clauses
        else [(pat, rhs)]
    | _ => [(pat, rhs)]
  end;
berghofe@22779
   438
berghofe@22779
   439
(* Destruct a case term into its scrutinee and fully flattened clauses;
   NONE if dest does not recognise t. *)
fun gen_strip_case dest t =
  Option.map (apsnd (maps (strip_case' dest))) (dest [] t);
berghofe@22779
   443
berghofe@22779
   444
(* Nested-pattern destruction of internal case terms. *)
fun strip_case tab d = gen_strip_case (dest_case tab d);
berghofe@22779
   445
(* Nested-pattern destruction of syntax-level case terms (used by the print
   translation). *)
fun strip_case' tab d = gen_strip_case (dest_case' tab d);
berghofe@22779
   446
berghofe@22779
   447
berghofe@22779
   448
(* print translation *)
berghofe@22779
   449
berghofe@22779
   450
(* Print translation: turn an application of case constant cname to ts back
   into "_case_syntax" with "_case1" clauses chained by "_case2".  In each
   clause, the pattern's free variables are marked as bound in both pattern
   and body, and constant names in the pattern are externalized.  Raises
   Match if the term cannot be destructed, so the translation does not
   apply. *)
fun case_tr' tab_of cname ctxt ts =
  let
    val thy = ProofContext.theory_of ctxt;
    val consts = ProofContext.consts_of ctxt;
    fun prep_lhs (Free p) = Syntax.mark_boundT p
      | prep_lhs (Const (s, _)) = Const (Consts.extern_early consts s, dummyT)
      | prep_lhs a = a;
    fun mk_clause (pat, rhs) =
      let
        val pat_vars = term_frees pat;
        fun prep_rhs (v as Free (s, _)) =
              if member op aconv pat_vars v then Syntax.mark_bound s else v
          | prep_rhs a = a
      in
        Syntax.const "_case1" $ map_aterms prep_lhs pat $ map_aterms prep_rhs rhs
      end;
    fun join (t, u) = Syntax.const "_case2" $ t $ u
  in
    case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
      NONE => raise Match
    | SOME (x, clauses) =>
        Syntax.const "_case_syntax" $ x $ foldr1 join (map mk_clause clauses)
  end;
berghofe@22779
   473
berghofe@22779
   474
end;