src/HOL/Tools/Datatype/datatype_case.ML
author haftmann
Thu Oct 08 19:33:03 2009 +0200 (2009-10-08)
changeset 32896 99cd75a18b78
parent 32671 fbd224850767
child 32952 aeb1e44fbc19
permissions -rw-r--r--
lookup for datatype constructors considers type annotations to resolve overloading
     1 (*  Title:      HOL/Tools/Datatype/datatype_case.ML
     2     Author:     Konrad Slind, Cambridge University Computer Laboratory
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     5 Nested case expressions on datatypes.
     6 *)
     7 
     8 signature DATATYPE_CASE =
     9 sig
    10   datatype config = Error | Warning | Quiet;
    11   val make_case: (string * typ -> DatatypeAux.info option) ->
    12     Proof.context -> config -> string list -> term -> (term * term) list ->
    13     term * (term * (int * bool)) list
    14   val dest_case: (string -> DatatypeAux.info option) -> bool ->
    15     string list -> term -> (term * (term * term) list) option
    16   val strip_case: (string -> DatatypeAux.info option) -> bool ->
    17     term -> (term * (term * term) list) option
    18   val case_tr: bool -> (theory -> string * typ -> DatatypeAux.info option)
    19     -> Proof.context -> term list -> term
    20   val case_tr': (theory -> string -> DatatypeAux.info option) ->
    21     string -> Proof.context -> term list -> term
    22 end;
    23 
    24 structure DatatypeCase : DATATYPE_CASE =
    25 struct
    26 
    27 datatype config = Error | Warning | Quiet;
    28 
    29 exception CASE_ERROR of string * int;
    30 
    31 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
    32 
    33 (*---------------------------------------------------------------------------
    34  * Get information about datatypes
    35  *---------------------------------------------------------------------------*)
    36 
    37 fun ty_info tab sT =
    38   case tab sT of
    39     SOME ({descr, case_name, index, sorts, ...} : DatatypeAux.info) =>
    40       let
    41         val (_, (tname, dts, constrs)) = nth descr index;
    42         val mk_ty = DatatypeAux.typ_of_dtyp descr sorts;
    43         val T = Type (tname, map mk_ty dts)
    44       in
    45         SOME {case_name = case_name,
    46           constructors = map (fn (cname, dts') =>
    47             Const (cname, Logic.varifyT (map mk_ty dts' ---> T))) constrs}
    48       end
    49   | NONE => NONE;
    50 
    51 
    52 (*---------------------------------------------------------------------------
    53  * Each pattern carries with it a tag (i,b) where
    54  * i is the clause it came from and
    55  * b=true indicates that clause was given by the user
    56  * (or is an instantiation of a user supplied pattern)
    57  * b=false --> i = ~1
    58  *---------------------------------------------------------------------------*)
    59 
    60 fun pattern_subst theta (tm, x) = (subst_free theta tm, x);
    61 
    62 fun row_of_pat x = fst (snd x);
    63 
    64 fun add_row_used ((prfx, pats), (tm, tag)) =
    65   fold Term.add_free_names (tm :: pats @ prfx);
    66 
    67 (* try to preserve names given by user *)
    68 fun default_names names ts =
    69   map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);
    70 
    71 fun strip_constraints (Const ("_constrain", _) $ t $ tT) =
    72       strip_constraints t ||> cons tT
    73   | strip_constraints t = (t, []);
    74 
    75 fun mk_fun_constrain tT t = Syntax.const "_constrain" $ t $
    76   (Syntax.free "fun" $ tT $ Syntax.free "dummy");
    77 
    78 
    79 (*---------------------------------------------------------------------------
    80  * Produce an instance of a constructor, plus genvars for its arguments.
    81  *---------------------------------------------------------------------------*)
    82 fun fresh_constr ty_match ty_inst colty used c =
    83   let
    84     val (_, Ty) = dest_Const c
    85     val Ts = binder_types Ty;
    86     val names = Name.variant_list used
    87       (DatatypeProp.make_tnames (map Logic.unvarifyT Ts));
    88     val ty = body_type Ty;
    89     val ty_theta = ty_match ty colty handle Type.TYPE_MATCH =>
    90       raise CASE_ERROR ("type mismatch", ~1)
    91     val c' = ty_inst ty_theta c
    92     val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
    93   in (c', gvars)
    94   end;
    95 
    96 
    97 (*---------------------------------------------------------------------------
    98  * Goes through a list of rows and picks out the ones beginning with a
    99  * pattern with constructor = name.
   100  *---------------------------------------------------------------------------*)
   101 fun mk_group (name, T) rows =
   102   let val k = length (binder_types T)
   103   in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
   104     fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of
   105         (Const (name', _), args) =>
   106           if name = name' then
   107             if length args = k then
   108               let val (args', cnstrts') = split_list (map strip_constraints args)
   109               in
   110                 ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
   111                  (default_names names args', map2 append cnstrts cnstrts'))
   112               end
   113             else raise CASE_ERROR
   114               ("Wrong number of arguments for constructor " ^ name, i)
   115           else ((in_group, row :: not_in_group), (names, cnstrts))
   116       | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
   117     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
   118   end;
   119 
   120 (*---------------------------------------------------------------------------
   121  * Partition the rows. Not efficient: we should use hashing.
   122  *---------------------------------------------------------------------------*)
   123 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
   124   | partition ty_match ty_inst type_of used constructors colty res_ty
   125         (rows as (((prfx, _ :: rstp), _) :: _)) =
   126       let
   127         fun part {constrs = [], rows = [], A} = rev A
   128           | part {constrs = [], rows = (_, (_, (i, _))) :: _, A} =
   129               raise CASE_ERROR ("Not a constructor pattern", i)
   130           | part {constrs = c :: crst, rows, A} =
   131               let
   132                 val ((in_group, not_in_group), (names, cnstrts)) =
   133                   mk_group (dest_Const c) rows;
   134                 val used' = fold add_row_used in_group used;
   135                 val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
   136                 val in_group' =
   137                   if null in_group  (* Constructor not given *)
   138                   then
   139                     let
   140                       val Ts = map type_of rstp;
   141                       val xs = Name.variant_list
   142                         (fold Term.add_free_names gvars used')
   143                         (replicate (length rstp) "x")
   144                     in
   145                       [((prfx, gvars @ map Free (xs ~~ Ts)),
   146                         (Const ("HOL.undefined", res_ty), (~1, false)))]
   147                     end
   148                   else in_group
   149               in
   150                 part{constrs = crst,
   151                   rows = not_in_group,
   152                   A = {constructor = c',
   153                     new_formals = gvars,
   154                     names = names,
   155                     constraints = cnstrts,
   156                     group = in_group'} :: A}
   157               end
   158       in part {constrs = constructors, rows = rows, A = []}
   159       end;
   160 
   161 (*---------------------------------------------------------------------------
   162  * Misc. routines used in mk_case
   163  *---------------------------------------------------------------------------*)
   164 
   165 fun mk_pat ((c, c'), l) =
   166   let
   167     val L = length (binder_types (fastype_of c))
   168     fun build (prfx, tag, plist) =
   169       let val (args, plist') = chop L plist
   170       in (prfx, tag, list_comb (c', args) :: plist') end
   171   in map build l end;
   172 
   173 fun v_to_prfx (prfx, v::pats) = (v::prfx,pats)
   174   | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
   175 
   176 fun v_to_pats (v::prfx,tag, pats) = (prfx, tag, v::pats)
   177   | v_to_pats _ = raise CASE_ERROR ("mk_case: v_to_pats", ~1);
   178 
   179 
   180 (*----------------------------------------------------------------------------
   181  * Translation of pattern terms into nested case expressions.
   182  *
   183  * This performs the translation and also builds the full set of patterns.
   184  * Thus it supports the construction of induction theorems even when an
   185  * incomplete set of patterns is given.
   186  *---------------------------------------------------------------------------*)
   187 
   188 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
   189   let
   190     val name = Name.variant used "a";
   191     fun expand constructors used ty ((_, []), _) =
   192           raise CASE_ERROR ("mk_case: expand_var_row", ~1)
   193       | expand constructors used ty (row as ((prfx, p :: rst), rhs)) =
   194           if is_Free p then
   195             let
   196               val used' = add_row_used row used;
   197               fun expnd c =
   198                 let val capp =
   199                   list_comb (fresh_constr ty_match ty_inst ty used' c)
   200                 in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs)
   201                 end
   202             in map expnd constructors end
   203           else [row]
   204     fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1)
   205       | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} =  (* Done *)
   206           ([(prfx, tag, [])], tm)
   207       | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} =
   208           mk {path = path, rows = [row]}
   209       | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} =
   210           let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows
   211           in case Option.map (apfst head_of)
   212             (find_first (not o is_Free o fst) col0) of
   213               NONE =>
   214                 let
   215                   val rows' = map (fn ((v, _), row) => row ||>
   216                     pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
   217                   val (pref_patl, tm) = mk {path = rstp, rows = rows'}
   218                 in (map v_to_pats pref_patl, tm) end
   219             | SOME (Const (cname, cT), i) => (case ty_info tab (cname, cT) of
   220                 NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
   221               | SOME {case_name, constructors} =>
   222                 let
   223                   val pty = body_type cT;
   224                   val used' = fold Term.add_free_names rstp used;
   225                   val nrows = maps (expand constructors used' pty) rows;
   226                   val subproblems = partition ty_match ty_inst type_of used'
   227                     constructors pty range_ty nrows;
   228                   val new_formals = map #new_formals subproblems
   229                   val constructors' = map #constructor subproblems
   230                   val news = map (fn {new_formals, group, ...} =>
   231                     {path = new_formals @ rstp, rows = group}) subproblems;
   232                   val (pat_rect, dtrees) = split_list (map mk news);
   233                   val case_functions = map2
   234                     (fn {new_formals, names, constraints, ...} =>
   235                        fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
   236                          Abs (if s = "" then name else s, T,
   237                            abstract_over (x, t)) |>
   238                          fold mk_fun_constrain cnstrts)
   239                            (new_formals ~~ names ~~ constraints))
   240                     subproblems dtrees;
   241                   val types = map type_of (case_functions @ [u]);
   242                   val case_const = Const (case_name, types ---> range_ty)
   243                   val tree = list_comb (case_const, case_functions @ [u])
   244                   val pat_rect1 = flat (map mk_pat
   245                     (constructors ~~ constructors' ~~ pat_rect))
   246                 in (pat_rect1, tree)
   247                 end)
   248             | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
   249                 Syntax.string_of_term ctxt t, i)
   250           end
   251       | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
   252   in mk
   253   end;
   254 
   255 fun case_error s = error ("Error in case expression:\n" ^ s);
   256 
   257 (* Repeated variable occurrences in a pattern are not allowed. *)
   258 fun no_repeat_vars ctxt pat = fold_aterms
   259   (fn x as Free (s, _) => (fn xs =>
   260         if member op aconv xs x then
   261           case_error (quote s ^ " occurs repeatedly in the pattern " ^
   262             quote (Syntax.string_of_term ctxt pat))
   263         else x :: xs)
   264     | _ => I) pat [];
   265 
   266 fun gen_make_case ty_match ty_inst type_of tab ctxt config used x clauses =
   267   let
   268     fun string_of_clause (pat, rhs) = Syntax.string_of_term ctxt
   269       (Syntax.const "_case1" $ pat $ rhs);
   270     val _ = map (no_repeat_vars ctxt o fst) clauses;
   271     val rows = map_index (fn (i, (pat, rhs)) =>
   272       (([], [pat]), (rhs, (i, true)))) clauses;
   273     val rangeT = (case distinct op = (map (type_of o snd) clauses) of
   274         [] => case_error "no clauses given"
   275       | [T] => T
   276       | _ => case_error "all cases must have the same result type");
   277     val used' = fold add_row_used rows used;
   278     val (patts, case_tm) = mk_case tab ctxt ty_match ty_inst type_of
   279         used' rangeT {path = [x], rows = rows}
   280       handle CASE_ERROR (msg, i) => case_error (msg ^
   281         (if i < 0 then ""
   282          else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
   283     val patts1 = map
   284       (fn (_, tag, [pat]) => (pat, tag)
   285         | _ => case_error "error in pattern-match translation") patts;
   286     val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1
   287     val finals = map row_of_pat patts2
   288     val originals = map (row_of_pat o #2) rows
   289     val _ = case originals \\ finals of
   290         [] => ()
   291         | is => (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
   292           ("The following clauses are redundant (covered by preceding clauses):\n" ^
   293            cat_lines (map (string_of_clause o nth clauses) is));
   294   in
   295     (case_tm, patts2)
   296   end;
   297 
   298 fun make_case tab ctxt = gen_make_case
   299   (match_type (ProofContext.theory_of ctxt)) Envir.subst_term_types fastype_of tab ctxt;
   300 val make_case_untyped = gen_make_case (K (K Vartab.empty))
   301   (K (Term.map_types (K dummyT))) (K dummyT);
   302 
   303 
   304 (* parse translation *)
   305 
   306 fun case_tr err tab_of ctxt [t, u] =
   307     let
   308       val thy = ProofContext.theory_of ctxt;
   309       (* replace occurrences of dummy_pattern by distinct variables *)
   310       (* internalize constant names                                 *)
   311       fun prep_pat ((c as Const ("_constrain", _)) $ t $ tT) used =
   312             let val (t', used') = prep_pat t used
   313             in (c $ t' $ tT, used') end
   314         | prep_pat (Const ("dummy_pattern", T)) used =
   315             let val x = Name.variant used "x"
   316             in (Free (x, T), x :: used) end
   317         | prep_pat (Const (s, T)) used =
   318             (case try (unprefix Syntax.constN) s of
   319                SOME c => (Const (c, T), used)
   320              | NONE => (Const (Sign.intern_const thy s, T), used))
   321         | prep_pat (v as Free (s, T)) used =
   322             let val s' = Sign.intern_const thy s
   323             in
   324               if Sign.declared_const thy s' then
   325                 (Const (s', T), used)
   326               else (v, used)
   327             end
   328         | prep_pat (t $ u) used =
   329             let
   330               val (t', used') = prep_pat t used;
   331               val (u', used'') = prep_pat u used'
   332             in
   333               (t' $ u', used'')
   334             end
   335         | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
   336       fun dest_case1 (t as Const ("_case1", _) $ l $ r) =
   337             let val (l', cnstrts) = strip_constraints l
   338             in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts)
   339             end
   340         | dest_case1 t = case_error "dest_case1";
   341       fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
   342         | dest_case2 t = [t];
   343       val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
   344       val (case_tm, _) = make_case_untyped (tab_of thy) ctxt
   345         (if err then Error else Warning) []
   346         (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT)
   347            (flat cnstrts) t) cases;
   348     in case_tm end
   349   | case_tr _ _ _ ts = case_error "case_tr";
   350 
   351 
   352 (*---------------------------------------------------------------------------
   353  * Pretty printing of nested case expressions
   354  *---------------------------------------------------------------------------*)
   355 
   356 (* destruct one level of pattern matching *)
   357 
   358 fun gen_dest_case name_of type_of tab d used t =
   359   case apfst name_of (strip_comb t) of
   360     (SOME cname, ts as _ :: _) =>
   361       let
   362         val (fs, x) = split_last ts;
   363         fun strip_abs i t =
   364           let
   365             val zs = strip_abs_vars t;
   366             val _ = if length zs < i then raise CASE_ERROR ("", 0) else ();
   367             val (xs, ys) = chop i zs;
   368             val u = list_abs (ys, strip_abs_body t);
   369             val xs' = map Free (Name.variant_list (OldTerm.add_term_names (u, used))
   370               (map fst xs) ~~ map snd xs)
   371           in (xs', subst_bounds (rev xs', u)) end;
   372         fun is_dependent i t =
   373           let val k = length (strip_abs_vars t) - i
   374           in k < 0 orelse exists (fn j => j >= k)
   375             (loose_bnos (strip_abs_body t))
   376           end;
   377         fun count_cases (_, _, true) = I
   378           | count_cases (c, (_, body), false) =
   379               AList.map_default op aconv (body, []) (cons c);
   380         val is_undefined = name_of #> equal (SOME "HOL.undefined");
   381         fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
   382       in case ty_info tab cname of
   383           SOME {constructors, case_name} =>
   384             if length fs = length constructors then
   385               let
   386                 val cases = map (fn (Const (s, U), t) =>
   387                   let
   388                     val k = length (binder_types U);
   389                     val p as (xs, _) = strip_abs k t
   390                   in
   391                     (Const (s, map type_of xs ---> type_of x),
   392                      p, is_dependent k t)
   393                   end) (constructors ~~ fs);
   394                 val cases' = sort (int_ord o swap o pairself (length o snd))
   395                   (fold_rev count_cases cases []);
   396                 val R = type_of t;
   397                 val dummy = if d then Const ("dummy_pattern", R)
   398                   else Free (Name.variant used "x", R)
   399               in
   400                 SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of
   401                   SOME (_, cs) =>
   402                   if length cs = length constructors then [hd cases]
   403                   else filter_out (fn (_, (_, body), _) => is_undefined body) cases
   404                 | NONE => case cases' of
   405                   [] => cases
   406                 | (default, cs) :: _ =>
   407                   if length cs = 1 then cases
   408                   else if length cs = length constructors then
   409                     [hd cases, (dummy, ([], default), false)]
   410                   else
   411                     filter_out (fn (c, _, _) => member op aconv cs c) cases @
   412                     [(dummy, ([], default), false)]))
   413               end handle CASE_ERROR _ => NONE
   414             else NONE
   415         | _ => NONE
   416       end
   417   | _ => NONE;
   418 
   419 val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
   420 val dest_case' = gen_dest_case
   421   (try (dest_Const #> fst #> unprefix Syntax.constN)) (K dummyT);
   422 
   423 
   424 (* destruct nested patterns *)
   425 
   426 fun strip_case'' dest (pat, rhs) =
   427   case dest (Term.add_free_names pat []) rhs of
   428     SOME (exp as Free _, clauses) =>
   429       if member op aconv (OldTerm.term_frees pat) exp andalso
   430         not (exists (fn (_, rhs') =>
   431           member op aconv (OldTerm.term_frees rhs') exp) clauses)
   432       then
   433         maps (strip_case'' dest) (map (fn (pat', rhs') =>
   434           (subst_free [(exp, pat')] pat, rhs')) clauses)
   435       else [(pat, rhs)]
   436   | _ => [(pat, rhs)];
   437 
   438 fun gen_strip_case dest t = case dest [] t of
   439     SOME (x, clauses) =>
   440       SOME (x, maps (strip_case'' dest) clauses)
   441   | NONE => NONE;
   442 
   443 val strip_case = gen_strip_case oo dest_case;
   444 val strip_case' = gen_strip_case oo dest_case';
   445 
   446 
   447 (* print translation *)
   448 
   449 fun case_tr' tab_of cname ctxt ts =
   450   let
   451     val thy = ProofContext.theory_of ctxt;
   452     val consts = ProofContext.consts_of ctxt;
   453     fun mk_clause (pat, rhs) =
   454       let val xs = Term.add_frees pat []
   455       in
   456         Syntax.const "_case1" $
   457           map_aterms
   458             (fn Free p => Syntax.mark_boundT p
   459               | Const (s, _) => Const (Consts.extern_early consts s, dummyT)
   460               | t => t) pat $
   461           map_aterms
   462             (fn x as Free (s, T) =>
   463                   if member (op =) xs (s, T) then Syntax.mark_bound s else x
   464               | t => t) rhs
   465       end
   466   in case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
   467       SOME (x, clauses) => Syntax.const "_case_syntax" $ x $
   468         foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u)
   469           (map mk_clause clauses)
   470     | NONE => raise Match
   471   end;
   472 
   473 end;