src/HOL/Tools/datatype_package/datatype_case.ML
changeset 31604 eb2f9d709296
parent 29623 1219985d24b5
child 31737 b3f63611784e
equal deleted inserted replaced
31603:fa30cd74d7d6 31604:eb2f9d709296
       
     1 (*  Title:      HOL/Tools/datatype_package/datatype_case.ML
       
     2     Author:     Konrad Slind, Cambridge University Computer Laboratory
       
     3     Author:     Stefan Berghofer, TU Muenchen
       
     4 
       
     5 Nested case expressions on datatypes.
       
     6 *)
       
     7 
       
     8 signature DATATYPE_CASE =
       
     9 sig
       
    10   val make_case: (string -> DatatypeAux.datatype_info option) ->
       
    11     Proof.context -> bool -> string list -> term -> (term * term) list ->
       
    12     term * (term * (int * bool)) list
       
    13   val dest_case: (string -> DatatypeAux.datatype_info option) -> bool ->
       
    14     string list -> term -> (term * (term * term) list) option
       
    15   val strip_case: (string -> DatatypeAux.datatype_info option) -> bool ->
       
    16     term -> (term * (term * term) list) option
       
    17   val case_tr: bool -> (theory -> string -> DatatypeAux.datatype_info option)
       
    18     -> Proof.context -> term list -> term
       
    19   val case_tr': (theory -> string -> DatatypeAux.datatype_info option) ->
       
    20     string -> Proof.context -> term list -> term
       
    21 end;
       
    22 
       
    23 structure DatatypeCase : DATATYPE_CASE =
       
    24 struct
       
    25 
       
    26 exception CASE_ERROR of string * int;
       
    27 
       
    28 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
       
    29 
       
    30 (*---------------------------------------------------------------------------
       
    31  * Get information about datatypes
       
    32  *---------------------------------------------------------------------------*)
       
    33 
       
    34 fun ty_info (tab : string -> DatatypeAux.datatype_info option) s =
       
    35   case tab s of
       
    36     SOME {descr, case_name, index, sorts, ...} =>
       
    37       let
       
    38         val (_, (tname, dts, constrs)) = nth descr index;
       
    39         val mk_ty = DatatypeAux.typ_of_dtyp descr sorts;
       
    40         val T = Type (tname, map mk_ty dts)
       
    41       in
       
    42         SOME {case_name = case_name,
       
    43           constructors = map (fn (cname, dts') =>
       
    44             Const (cname, Logic.varifyT (map mk_ty dts' ---> T))) constrs}
       
    45       end
       
    46   | NONE => NONE;
       
    47 
       
    48 
       
    49 (*---------------------------------------------------------------------------
       
    50  * Each pattern carries with it a tag (i,b) where
       
    51  * i is the clause it came from and
       
    52  * b=true indicates that clause was given by the user
       
    53  * (or is an instantiation of a user supplied pattern)
       
    54  * b=false --> i = ~1
       
    55  *---------------------------------------------------------------------------*)
       
    56 
       
    57 fun pattern_subst theta (tm, x) = (subst_free theta tm, x);
       
    58 
       
    59 fun row_of_pat x = fst (snd x);
       
    60 
       
    61 fun add_row_used ((prfx, pats), (tm, tag)) =
       
    62   fold Term.add_free_names (tm :: pats @ prfx);
       
    63 
       
    64 (* try to preserve names given by user *)
       
    65 fun default_names names ts =
       
    66   map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);
       
    67 
       
    68 fun strip_constraints (Const ("_constrain", _) $ t $ tT) =
       
    69       strip_constraints t ||> cons tT
       
    70   | strip_constraints t = (t, []);
       
    71 
       
    72 fun mk_fun_constrain tT t = Syntax.const "_constrain" $ t $
       
    73   (Syntax.free "fun" $ tT $ Syntax.free "dummy");
       
    74 
       
    75 
       
    76 (*---------------------------------------------------------------------------
       
    77  * Produce an instance of a constructor, plus genvars for its arguments.
       
    78  *---------------------------------------------------------------------------*)
       
    79 fun fresh_constr ty_match ty_inst colty used c =
       
    80   let
       
    81     val (_, Ty) = dest_Const c
       
    82     val Ts = binder_types Ty;
       
    83     val names = Name.variant_list used
       
    84       (DatatypeProp.make_tnames (map Logic.unvarifyT Ts));
       
    85     val ty = body_type Ty;
       
    86     val ty_theta = ty_match ty colty handle Type.TYPE_MATCH =>
       
    87       raise CASE_ERROR ("type mismatch", ~1)
       
    88     val c' = ty_inst ty_theta c
       
    89     val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
       
    90   in (c', gvars)
       
    91   end;
       
    92 
       
    93 
       
    94 (*---------------------------------------------------------------------------
       
    95  * Goes through a list of rows and picks out the ones beginning with a
       
    96  * pattern with constructor = name.
       
    97  *---------------------------------------------------------------------------*)
       
    98 fun mk_group (name, T) rows =
       
    99   let val k = length (binder_types T)
       
   100   in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
       
   101     fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of
       
   102         (Const (name', _), args) =>
       
   103           if name = name' then
       
   104             if length args = k then
       
   105               let val (args', cnstrts') = split_list (map strip_constraints args)
       
   106               in
       
   107                 ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
       
   108                  (default_names names args', map2 append cnstrts cnstrts'))
       
   109               end
       
   110             else raise CASE_ERROR
       
   111               ("Wrong number of arguments for constructor " ^ name, i)
       
   112           else ((in_group, row :: not_in_group), (names, cnstrts))
       
   113       | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
       
   114     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
       
   115   end;
       
   116 
       
   117 (*---------------------------------------------------------------------------
       
   118  * Partition the rows. Not efficient: we should use hashing.
       
   119  *---------------------------------------------------------------------------*)
       
   120 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
       
   121   | partition ty_match ty_inst type_of used constructors colty res_ty
       
   122         (rows as (((prfx, _ :: rstp), _) :: _)) =
       
   123       let
       
   124         fun part {constrs = [], rows = [], A} = rev A
       
   125           | part {constrs = [], rows = (_, (_, (i, _))) :: _, A} =
       
   126               raise CASE_ERROR ("Not a constructor pattern", i)
       
   127           | part {constrs = c :: crst, rows, A} =
       
   128               let
       
   129                 val ((in_group, not_in_group), (names, cnstrts)) =
       
   130                   mk_group (dest_Const c) rows;
       
   131                 val used' = fold add_row_used in_group used;
       
   132                 val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
       
   133                 val in_group' =
       
   134                   if null in_group  (* Constructor not given *)
       
   135                   then
       
   136                     let
       
   137                       val Ts = map type_of rstp;
       
   138                       val xs = Name.variant_list
       
   139                         (fold Term.add_free_names gvars used')
       
   140                         (replicate (length rstp) "x")
       
   141                     in
       
   142                       [((prfx, gvars @ map Free (xs ~~ Ts)),
       
   143                         (Const ("HOL.undefined", res_ty), (~1, false)))]
       
   144                     end
       
   145                   else in_group
       
   146               in
       
   147                 part{constrs = crst,
       
   148                   rows = not_in_group,
       
   149                   A = {constructor = c',
       
   150                     new_formals = gvars,
       
   151                     names = names,
       
   152                     constraints = cnstrts,
       
   153                     group = in_group'} :: A}
       
   154               end
       
   155       in part {constrs = constructors, rows = rows, A = []}
       
   156       end;
       
   157 
       
   158 (*---------------------------------------------------------------------------
       
   159  * Misc. routines used in mk_case
       
   160  *---------------------------------------------------------------------------*)
       
   161 
       
   162 fun mk_pat ((c, c'), l) =
       
   163   let
       
   164     val L = length (binder_types (fastype_of c))
       
   165     fun build (prfx, tag, plist) =
       
   166       let val (args, plist') = chop L plist
       
   167       in (prfx, tag, list_comb (c', args) :: plist') end
       
   168   in map build l end;
       
   169 
       
   170 fun v_to_prfx (prfx, v::pats) = (v::prfx,pats)
       
   171   | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
       
   172 
       
   173 fun v_to_pats (v::prfx,tag, pats) = (prfx, tag, v::pats)
       
   174   | v_to_pats _ = raise CASE_ERROR ("mk_case: v_to_pats", ~1);
       
   175 
       
   176 
       
   177 (*----------------------------------------------------------------------------
       
   178  * Translation of pattern terms into nested case expressions.
       
   179  *
       
   180  * This performs the translation and also builds the full set of patterns.
       
   181  * Thus it supports the construction of induction theorems even when an
       
   182  * incomplete set of patterns is given.
       
   183  *---------------------------------------------------------------------------*)
       
   184 
       
   185 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
       
   186   let
       
   187     val name = Name.variant used "a";
       
   188     fun expand constructors used ty ((_, []), _) =
       
   189           raise CASE_ERROR ("mk_case: expand_var_row", ~1)
       
   190       | expand constructors used ty (row as ((prfx, p :: rst), rhs)) =
       
   191           if is_Free p then
       
   192             let
       
   193               val used' = add_row_used row used;
       
   194               fun expnd c =
       
   195                 let val capp =
       
   196                   list_comb (fresh_constr ty_match ty_inst ty used' c)
       
   197                 in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs)
       
   198                 end
       
   199             in map expnd constructors end
       
   200           else [row]
       
   201     fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1)
       
   202       | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} =  (* Done *)
       
   203           ([(prfx, tag, [])], tm)
       
   204       | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} =
       
   205           mk {path = path, rows = [row]}
       
   206       | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} =
       
   207           let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows
       
   208           in case Option.map (apfst head_of)
       
   209             (find_first (not o is_Free o fst) col0) of
       
   210               NONE =>
       
   211                 let
       
   212                   val rows' = map (fn ((v, _), row) => row ||>
       
   213                     pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
       
   214                   val (pref_patl, tm) = mk {path = rstp, rows = rows'}
       
   215                 in (map v_to_pats pref_patl, tm) end
       
   216             | SOME (Const (cname, cT), i) => (case ty_info tab cname of
       
   217                 NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
       
   218               | SOME {case_name, constructors} =>
       
   219                 let
       
   220                   val pty = body_type cT;
       
   221                   val used' = fold Term.add_free_names rstp used;
       
   222                   val nrows = maps (expand constructors used' pty) rows;
       
   223                   val subproblems = partition ty_match ty_inst type_of used'
       
   224                     constructors pty range_ty nrows;
       
   225                   val new_formals = map #new_formals subproblems
       
   226                   val constructors' = map #constructor subproblems
       
   227                   val news = map (fn {new_formals, group, ...} =>
       
   228                     {path = new_formals @ rstp, rows = group}) subproblems;
       
   229                   val (pat_rect, dtrees) = split_list (map mk news);
       
   230                   val case_functions = map2
       
   231                     (fn {new_formals, names, constraints, ...} =>
       
   232                        fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
       
   233                          Abs (if s = "" then name else s, T,
       
   234                            abstract_over (x, t)) |>
       
   235                          fold mk_fun_constrain cnstrts)
       
   236                            (new_formals ~~ names ~~ constraints))
       
   237                     subproblems dtrees;
       
   238                   val types = map type_of (case_functions @ [u]);
       
   239                   val case_const = Const (case_name, types ---> range_ty)
       
   240                   val tree = list_comb (case_const, case_functions @ [u])
       
   241                   val pat_rect1 = flat (map mk_pat
       
   242                     (constructors ~~ constructors' ~~ pat_rect))
       
   243                 in (pat_rect1, tree)
       
   244                 end)
       
   245             | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
       
   246                 Syntax.string_of_term ctxt t, i)
       
   247           end
       
   248       | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
       
   249   in mk
       
   250   end;
       
   251 
       
   252 fun case_error s = error ("Error in case expression:\n" ^ s);
       
   253 
       
   254 (* Repeated variable occurrences in a pattern are not allowed. *)
       
   255 fun no_repeat_vars ctxt pat = fold_aterms
       
   256   (fn x as Free (s, _) => (fn xs =>
       
   257         if member op aconv xs x then
       
   258           case_error (quote s ^ " occurs repeatedly in the pattern " ^
       
   259             quote (Syntax.string_of_term ctxt pat))
       
   260         else x :: xs)
       
   261     | _ => I) pat [];
       
   262 
       
   263 fun gen_make_case ty_match ty_inst type_of tab ctxt err used x clauses =
       
   264   let
       
   265     fun string_of_clause (pat, rhs) = Syntax.string_of_term ctxt
       
   266       (Syntax.const "_case1" $ pat $ rhs);
       
   267     val _ = map (no_repeat_vars ctxt o fst) clauses;
       
   268     val rows = map_index (fn (i, (pat, rhs)) =>
       
   269       (([], [pat]), (rhs, (i, true)))) clauses;
       
   270     val rangeT = (case distinct op = (map (type_of o snd) clauses) of
       
   271         [] => case_error "no clauses given"
       
   272       | [T] => T
       
   273       | _ => case_error "all cases must have the same result type");
       
   274     val used' = fold add_row_used rows used;
       
   275     val (patts, case_tm) = mk_case tab ctxt ty_match ty_inst type_of
       
   276         used' rangeT {path = [x], rows = rows}
       
   277       handle CASE_ERROR (msg, i) => case_error (msg ^
       
   278         (if i < 0 then ""
       
   279          else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
       
   280     val patts1 = map
       
   281       (fn (_, tag, [pat]) => (pat, tag)
       
   282         | _ => case_error "error in pattern-match translation") patts;
       
   283     val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1
       
   284     val finals = map row_of_pat patts2
       
   285     val originals = map (row_of_pat o #2) rows
       
   286     val _ = case originals \\ finals of
       
   287         [] => ()
       
   288       | is => (if err then case_error else warning)
       
   289           ("The following clauses are redundant (covered by preceding clauses):\n" ^
       
   290            cat_lines (map (string_of_clause o nth clauses) is));
       
   291   in
       
   292     (case_tm, patts2)
       
   293   end;
       
   294 
       
   295 fun make_case tab ctxt = gen_make_case
       
   296   (match_type (ProofContext.theory_of ctxt)) Envir.subst_TVars fastype_of tab ctxt;
       
   297 val make_case_untyped = gen_make_case (K (K Vartab.empty))
       
   298   (K (Term.map_types (K dummyT))) (K dummyT);
       
   299 
       
   300 
       
   301 (* parse translation *)
       
   302 
       
   303 fun case_tr err tab_of ctxt [t, u] =
       
   304     let
       
   305       val thy = ProofContext.theory_of ctxt;
       
   306       (* replace occurrences of dummy_pattern by distinct variables *)
       
   307       (* internalize constant names                                 *)
       
   308       fun prep_pat ((c as Const ("_constrain", _)) $ t $ tT) used =
       
   309             let val (t', used') = prep_pat t used
       
   310             in (c $ t' $ tT, used') end
       
   311         | prep_pat (Const ("dummy_pattern", T)) used =
       
   312             let val x = Name.variant used "x"
       
   313             in (Free (x, T), x :: used) end
       
   314         | prep_pat (Const (s, T)) used =
       
   315             (case try (unprefix Syntax.constN) s of
       
   316                SOME c => (Const (c, T), used)
       
   317              | NONE => (Const (Sign.intern_const thy s, T), used))
       
   318         | prep_pat (v as Free (s, T)) used =
       
   319             let val s' = Sign.intern_const thy s
       
   320             in
       
   321               if Sign.declared_const thy s' then
       
   322                 (Const (s', T), used)
       
   323               else (v, used)
       
   324             end
       
   325         | prep_pat (t $ u) used =
       
   326             let
       
   327               val (t', used') = prep_pat t used;
       
   328               val (u', used'') = prep_pat u used'
       
   329             in
       
   330               (t' $ u', used'')
       
   331             end
       
   332         | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
       
   333       fun dest_case1 (t as Const ("_case1", _) $ l $ r) =
       
   334             let val (l', cnstrts) = strip_constraints l
       
   335             in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts)
       
   336             end
       
   337         | dest_case1 t = case_error "dest_case1";
       
   338       fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
       
   339         | dest_case2 t = [t];
       
   340       val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
       
   341       val (case_tm, _) = make_case_untyped (tab_of thy) ctxt err []
       
   342         (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT)
       
   343            (flat cnstrts) t) cases;
       
   344     in case_tm end
       
   345   | case_tr _ _ _ ts = case_error "case_tr";
       
   346 
       
   347 
       
   348 (*---------------------------------------------------------------------------
       
   349  * Pretty printing of nested case expressions
       
   350  *---------------------------------------------------------------------------*)
       
   351 
       
   352 (* destruct one level of pattern matching *)
       
   353 
       
   354 fun gen_dest_case name_of type_of tab d used t =
       
   355   case apfst name_of (strip_comb t) of
       
   356     (SOME cname, ts as _ :: _) =>
       
   357       let
       
   358         val (fs, x) = split_last ts;
       
   359         fun strip_abs i t =
       
   360           let
       
   361             val zs = strip_abs_vars t;
       
   362             val _ = if length zs < i then raise CASE_ERROR ("", 0) else ();
       
   363             val (xs, ys) = chop i zs;
       
   364             val u = list_abs (ys, strip_abs_body t);
       
   365             val xs' = map Free (Name.variant_list (OldTerm.add_term_names (u, used))
       
   366               (map fst xs) ~~ map snd xs)
       
   367           in (xs', subst_bounds (rev xs', u)) end;
       
   368         fun is_dependent i t =
       
   369           let val k = length (strip_abs_vars t) - i
       
   370           in k < 0 orelse exists (fn j => j >= k)
       
   371             (loose_bnos (strip_abs_body t))
       
   372           end;
       
   373         fun count_cases (_, _, true) = I
       
   374           | count_cases (c, (_, body), false) =
       
   375               AList.map_default op aconv (body, []) (cons c);
       
   376         val is_undefined = name_of #> equal (SOME "HOL.undefined");
       
   377         fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
       
   378       in case ty_info tab cname of
       
   379           SOME {constructors, case_name} =>
       
   380             if length fs = length constructors then
       
   381               let
       
   382                 val cases = map (fn (Const (s, U), t) =>
       
   383                   let
       
   384                     val k = length (binder_types U);
       
   385                     val p as (xs, _) = strip_abs k t
       
   386                   in
       
   387                     (Const (s, map type_of xs ---> type_of x),
       
   388                      p, is_dependent k t)
       
   389                   end) (constructors ~~ fs);
       
   390                 val cases' = sort (int_ord o swap o pairself (length o snd))
       
   391                   (fold_rev count_cases cases []);
       
   392                 val R = type_of t;
       
   393                 val dummy = if d then Const ("dummy_pattern", R)
       
   394                   else Free (Name.variant used "x", R)
       
   395               in
       
   396                 SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of
       
   397                   SOME (_, cs) =>
       
   398                   if length cs = length constructors then [hd cases]
       
   399                   else filter_out (fn (_, (_, body), _) => is_undefined body) cases
       
   400                 | NONE => case cases' of
       
   401                   [] => cases
       
   402                 | (default, cs) :: _ =>
       
   403                   if length cs = 1 then cases
       
   404                   else if length cs = length constructors then
       
   405                     [hd cases, (dummy, ([], default), false)]
       
   406                   else
       
   407                     filter_out (fn (c, _, _) => member op aconv cs c) cases @
       
   408                     [(dummy, ([], default), false)]))
       
   409               end handle CASE_ERROR _ => NONE
       
   410             else NONE
       
   411         | _ => NONE
       
   412       end
       
   413   | _ => NONE;
       
   414 
       
   415 val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
       
   416 val dest_case' = gen_dest_case
       
   417   (try (dest_Const #> fst #> unprefix Syntax.constN)) (K dummyT);
       
   418 
       
   419 
       
   420 (* destruct nested patterns *)
       
   421 
       
   422 fun strip_case'' dest (pat, rhs) =
       
   423   case dest (Term.add_free_names pat []) rhs of
       
   424     SOME (exp as Free _, clauses) =>
       
   425       if member op aconv (OldTerm.term_frees pat) exp andalso
       
   426         not (exists (fn (_, rhs') =>
       
   427           member op aconv (OldTerm.term_frees rhs') exp) clauses)
       
   428       then
       
   429         maps (strip_case'' dest) (map (fn (pat', rhs') =>
       
   430           (subst_free [(exp, pat')] pat, rhs')) clauses)
       
   431       else [(pat, rhs)]
       
   432   | _ => [(pat, rhs)];
       
   433 
       
   434 fun gen_strip_case dest t = case dest [] t of
       
   435     SOME (x, clauses) =>
       
   436       SOME (x, maps (strip_case'' dest) clauses)
       
   437   | NONE => NONE;
       
   438 
       
   439 val strip_case = gen_strip_case oo dest_case;
       
   440 val strip_case' = gen_strip_case oo dest_case';
       
   441 
       
   442 
       
   443 (* print translation *)
       
   444 
       
   445 fun case_tr' tab_of cname ctxt ts =
       
   446   let
       
   447     val thy = ProofContext.theory_of ctxt;
       
   448     val consts = ProofContext.consts_of ctxt;
       
   449     fun mk_clause (pat, rhs) =
       
   450       let val xs = Term.add_frees pat []
       
   451       in
       
   452         Syntax.const "_case1" $
       
   453           map_aterms
       
   454             (fn Free p => Syntax.mark_boundT p
       
   455               | Const (s, _) => Const (Consts.extern_early consts s, dummyT)
       
   456               | t => t) pat $
       
   457           map_aterms
       
   458             (fn x as Free (s, T) =>
       
   459                   if member (op =) xs (s, T) then Syntax.mark_bound s else x
       
   460               | t => t) rhs
       
   461       end
       
   462   in case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
       
   463       SOME (x, clauses) => Syntax.const "_case_syntax" $ x $
       
   464         foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u)
       
   465           (map mk_clause clauses)
       
   466     | NONE => raise Match
       
   467   end;
       
   468 
       
   469 end;