src/HOL/Tools/case_translation.ML
changeset 51673 4dfa00e264d8
child 51674 2b1498a2ce85
equal deleted inserted replaced
51672:d5c5e088ebdf 51673:4dfa00e264d8
       
     1 (*  Title:      HOL/Tools/case_translation.ML
       
     2     Author:     Konrad Slind, Cambridge University Computer Laboratory
       
     3     Author:     Stefan Berghofer, TU Muenchen
       
     4     Author:     Dmitriy Traytel, TU Muenchen
       
     5 
       
     6 Nested case expressions via a generic data slot for case combinators and constructors.
       
     7 *)
       
     8 
       
     9 signature CASE_TRANSLATION =
       
    10 sig
       
    11   datatype config = Error | Warning | Quiet
       
    12   val case_tr: Proof.context -> term list -> term
       
    13   val lookup_by_constr: Proof.context -> string * typ -> (term * term list) option
       
    14   val lookup_by_constr_permissive: Proof.context -> string * typ -> (term * term list) option
       
    15   val lookup_by_case: Proof.context -> string -> (term * term list) option
       
    16   val make_case:  Proof.context -> config -> Name.context -> term -> (term * term) list -> term
       
    17   val print_case_translations: Proof.context -> unit
       
    18   val strip_case: Proof.context -> bool -> term -> term
       
    19   val show_cases: bool Config.T
       
    20   val setup: theory -> theory
       
    21   val register: term -> term list -> Context.generic -> Context.generic
       
    22 end;
       
    23 
       
    24 structure Case_Translation: CASE_TRANSLATION =
       
    25 struct
       
    26 
       
    27 (** data management **)
       
    28 
       
    29 datatype data = Data of
       
    30   {constrs: (string * (term * term list)) list Symtab.table,
       
    31    cases: (term * term list) Symtab.table};
       
    32 
       
    33 fun make_data (constrs, cases) = Data {constrs = constrs, cases = cases};
       
    34 
       
    35 structure Data = Generic_Data
       
    36 (
       
    37   type T = data;
       
    38   val empty = make_data (Symtab.empty, Symtab.empty);
       
    39   val extend = I;
       
    40   fun merge
       
    41     (Data {constrs = constrs1, cases = cases1},
       
    42      Data {constrs = constrs2, cases = cases2}) =
       
    43     make_data
       
    44       (Symtab.join (K (AList.merge (op =) (K true))) (constrs1, constrs2),
       
    45       Symtab.merge (K true) (cases1, cases2));
       
    46 );
       
    47 
       
    48 fun map_data f =
       
    49   Data.map (fn Data {constrs, cases} => make_data (f (constrs, cases)));
       
    50 fun map_constrs f = map_data (fn (constrs, cases) => (f constrs, cases));
       
    51 fun map_cases f = map_data (fn (constrs, cases) => (constrs, f cases));
       
    52 
       
    53 val rep_data = (fn Data args => args) o Data.get o Context.Proof;
       
    54 
       
    55 fun T_of_data (comb, constrs) =
       
    56   fastype_of comb
       
    57   |> funpow (length constrs) range_type
       
    58   |> domain_type;
       
    59 
       
    60 val Tname_of_data = fst o dest_Type o T_of_data;
       
    61 
       
    62 val constrs_of = #constrs o rep_data;
       
    63 val cases_of = #cases o rep_data;
       
    64 
       
    65 fun lookup_by_constr ctxt (c, T) =
       
    66   let
       
    67     val tab = Symtab.lookup_list (constrs_of ctxt) c;
       
    68   in
       
    69     (case body_type T of
       
    70       Type (tyco, _) => AList.lookup (op =) tab tyco
       
    71     | _ => NONE)
       
    72   end;
       
    73 
       
    74 fun lookup_by_constr_permissive ctxt (c, T) =
       
    75   let
       
    76     val tab = Symtab.lookup_list (constrs_of ctxt) c;
       
    77     val hint = (case body_type T of Type (tyco, _) => SOME tyco | _ => NONE);
       
    78     val default = if null tab then NONE else SOME (snd (List.last tab));
       
    79     (*conservative wrt. overloaded constructors*)
       
    80   in
       
    81     (case hint of
       
    82       NONE => default
       
    83     | SOME tyco =>
       
    84         (case AList.lookup (op =) tab tyco of
       
    85           NONE => default (*permissive*)
       
    86         | SOME info => SOME info))
       
    87   end;
       
    88 
       
    89 val lookup_by_case = Symtab.lookup o cases_of;
       
    90 
       
    91 
       
    92 (** installation **)
       
    93 
       
    94 fun case_error s = error ("Error in case expression:\n" ^ s);
       
    95 
       
    96 val name_of = try (dest_Const #> fst);
       
    97 
       
    98 (* parse translation *)
       
    99 
       
   100 fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
       
   101 
       
   102 fun case_tr ctxt [t, u] =
       
   103       let
       
   104         val thy = Proof_Context.theory_of ctxt;
       
   105 
       
   106         fun is_const s =
       
   107           Sign.declared_const thy (Proof_Context.intern_const ctxt s);
       
   108 
       
   109         fun abs p tTs t = Syntax.const @{const_syntax case_abs} $
       
   110           fold constrain_Abs tTs (absfree p t);
       
   111 
       
   112         fun abs_pat (Const ("_constrain", _) $ t $ tT) tTs = abs_pat t (tT :: tTs)
       
   113           | abs_pat (Free (p as (x, _))) tTs =
       
   114               if is_const x then I else abs p tTs
       
   115           | abs_pat (t $ u) _ = abs_pat u [] #> abs_pat t []
       
   116           | abs_pat _ _ = I;
       
   117 
       
   118         fun dest_case1 (Const (@{syntax_const "_case1"}, _) $ l $ r) =
       
   119               abs_pat l []
       
   120                 (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l $ r)
       
   121           | dest_case1 _ = case_error "dest_case1";
       
   122 
       
   123         fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
       
   124           | dest_case2 t = [t];
       
   125       in
       
   126         fold_rev
       
   127           (fn t => fn u =>
       
   128              Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
       
   129           (dest_case2 u)
       
   130           (Syntax.const @{const_syntax case_nil}) $ t
       
   131       end
       
   132   | case_tr _ _ = case_error "case_tr";
       
   133 
       
   134 val trfun_setup =
       
   135   Sign.add_advanced_trfuns ([],
       
   136     [(@{syntax_const "_case_syntax"}, case_tr)],
       
   137     [], []);
       
   138 
       
   139 
       
   140 (* print translation *)
       
   141 
       
   142 fun case_tr' [t, u, x] =
       
   143       let
       
   144         fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
       
   145               let val (s', used') = Name.variant s used
       
   146               in mk_clause t ((s', T) :: xs) used' end
       
   147           | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
       
   148               Syntax.const @{syntax_const "_case1"} $
       
   149                 subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
       
   150                 subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
       
   151 
       
   152         fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
       
   153           | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
       
   154               mk_clauses' t u
       
   155         and mk_clauses' t u =
       
   156               mk_clause t [] (Term.declare_term_frees t Name.context) ::
       
   157               mk_clauses u
       
   158       in
       
   159         Syntax.const @{syntax_const "_case_syntax"} $ x $
       
   160           foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
       
   161             (mk_clauses' t u)
       
   162       end;
       
   163 
       
   164 val trfun_setup' = Sign.add_trfuns
       
   165   ([], [], [(@{const_syntax "case_cons"}, case_tr')], []);
       
   166 
       
   167 
       
   168 (* declarations *)
       
   169 
       
   170 fun register raw_case_comb raw_constrs context =
       
   171   let
       
   172     val ctxt = Context.proof_of context;
       
   173     val case_comb = singleton (Variable.polymorphic ctxt) raw_case_comb;
       
   174     val constrs = Variable.polymorphic ctxt raw_constrs;
       
   175     val case_key = case_comb |> dest_Const |> fst;
       
   176     val constr_keys = map (fst o dest_Const) constrs;
       
   177     val data = (case_comb, constrs);
       
   178     val Tname = Tname_of_data data;
       
   179     val update_constrs = fold (fn key => Symtab.cons_list (key, (Tname, data))) constr_keys;
       
   180     val update_cases = Symtab.update (case_key, data);
       
   181   in
       
   182     context
       
   183     |> map_constrs update_constrs
       
   184     |> map_cases update_cases
       
   185   end;
       
   186 
       
   187 
       
   188 (* (Un)check phases *)
       
   189 
       
   190 datatype config = Error | Warning | Quiet;
       
   191 
       
   192 exception CASE_ERROR of string * int;
       
   193 
       
   194 fun match_type ctxt pat ob =
       
   195   Sign.typ_match (Proof_Context.theory_of ctxt) (pat, ob) Vartab.empty;
       
   196 
       
   197 
       
   198 (*Each pattern carries with it a tag i, which denotes the clause it
       
   199 came from. i = ~1 indicates that the clause was added by pattern
       
   200 completion.*)
       
   201 
       
   202 fun add_row_used ((prfx, pats), (tm, tag)) =
       
   203   fold Term.declare_term_frees (tm :: pats @ map Free prfx);
       
   204 
       
   205 (* try to preserve names given by user *)
       
   206 fun default_name "" (Free (name', _)) = name'
       
   207   | default_name name _ = name;
       
   208 
       
   209 
       
   210 (*Produce an instance of a constructor, plus fresh variables for its arguments.*)
       
   211 fun fresh_constr ctxt colty used c =
       
   212   let
       
   213     val (_, T) = dest_Const c;
       
   214     val Ts = binder_types T;
       
   215     val (names, _) = fold_map Name.variant
       
   216       (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
       
   217     val ty = body_type T;
       
   218     val ty_theta = match_type ctxt ty colty
       
   219       handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
       
   220     val c' = Envir.subst_term_types ty_theta c;
       
   221     val gvars = map (Envir.subst_term_types ty_theta o Free) (names ~~ Ts);
       
   222   in (c', gvars) end;
       
   223 
       
   224 (*Go through a list of rows and pick out the ones beginning with a
       
   225   pattern with constructor = name.*)
       
   226 fun mk_group (name, T) rows =
       
   227   let val k = length (binder_types T) in
       
   228     fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) =>
       
   229       fn ((in_group, not_in_group), names) =>
       
   230         (case strip_comb p of
       
   231           (Const (name', _), args) =>
       
   232             if name = name' then
       
   233               if length args = k then
       
   234                 ((((prfx, args @ ps), rhs) :: in_group, not_in_group),
       
   235                  map2 default_name names args)
       
   236               else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
       
   237             else ((in_group, row :: not_in_group), names)
       
   238         | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
       
   239     rows (([], []), replicate k "") |>> pairself rev
       
   240   end;
       
   241 
       
   242 
       
   243 (* Partitioning *)
       
   244 
       
   245 fun partition _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
       
   246   | partition ctxt used constructors colty res_ty
       
   247         (rows as (((prfx, _ :: ps), _) :: _)) =
       
   248       let
       
   249         fun part [] [] = []
       
   250           | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
       
   251           | part (c :: cs) rows =
       
   252               let
       
   253                 val ((in_group, not_in_group), names) = mk_group (dest_Const c) rows;
       
   254                 val used' = fold add_row_used in_group used;
       
   255                 val (c', gvars) = fresh_constr ctxt colty used' c;
       
   256                 val in_group' =
       
   257                   if null in_group  (* Constructor not given *)
       
   258                   then
       
   259                     let
       
   260                       val Ts = map fastype_of ps;
       
   261                       val (xs, _) =
       
   262                         fold_map Name.variant
       
   263                           (replicate (length ps) "x")
       
   264                           (fold Term.declare_term_frees gvars used');
       
   265                     in
       
   266                       [((prfx, gvars @ map Free (xs ~~ Ts)),
       
   267                         (Const (@{const_name undefined}, res_ty), ~1))]
       
   268                     end
       
   269                   else in_group;
       
   270               in
       
   271                 {constructor = c',
       
   272                  new_formals = gvars,
       
   273                  names = names,
       
   274                  group = in_group'} :: part cs not_in_group
       
   275               end;
       
   276       in part constructors rows end;
       
   277 
       
   278 fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats)
       
   279   | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
       
   280 
       
   281 
       
   282 (* Translation of pattern terms into nested case expressions. *)
       
   283 
       
   284 fun mk_case ctxt used range_ty =
       
   285   let
       
   286     val get_info = lookup_by_constr_permissive ctxt;
       
   287 
       
   288     fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1)
       
   289       | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
       
   290           if is_Free p then
       
   291             let
       
   292               val used' = add_row_used row used;
       
   293               fun expnd c =
       
   294                 let val capp = list_comb (fresh_constr ctxt ty used' c)
       
   295                 in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
       
   296             in map expnd constructors end
       
   297           else [row];
       
   298 
       
   299     val (name, _) = Name.variant "a" used;
       
   300 
       
   301     fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
       
   302       | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
       
   303       | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
       
   304       | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
       
   305           let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
       
   306             (case Option.map (apfst head_of)
       
   307                 (find_first (not o is_Free o fst) col0) of
       
   308               NONE =>
       
   309                 let
       
   310                   val rows' = map (fn ((v, _), row) => row ||>
       
   311                     apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
       
   312                 in mk us rows' end
       
   313             | SOME (Const (cname, cT), i) =>
       
   314                 (case get_info (cname, cT) of
       
   315                   NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i)
       
   316                 | SOME (case_comb, constructors) =>
       
   317                     let
       
   318                       val pty = body_type cT;
       
   319                       val used' = fold Term.declare_term_frees us used;
       
   320                       val nrows = maps (expand constructors used' pty) rows;
       
   321                       val subproblems =
       
   322                         partition ctxt used' constructors pty range_ty nrows;
       
   323                       val (pat_rect, dtrees) =
       
   324                         split_list (map (fn {new_formals, group, ...} =>
       
   325                           mk (new_formals @ us) group) subproblems);
       
   326                       val case_functions =
       
   327                         map2 (fn {new_formals, names, ...} =>
       
   328                           fold_rev (fn (x as Free (_, T), s) => fn t =>
       
   329                             Abs (if s = "" then name else s, T, abstract_over (x, t)))
       
   330                               (new_formals ~~ names))
       
   331                         subproblems dtrees;
       
   332                       val types = map fastype_of (case_functions @ [u]);
       
   333                       val case_const = Const (name_of case_comb |> the, types ---> range_ty);
       
   334                       val tree = list_comb (case_const, case_functions @ [u]);
       
   335                     in (flat pat_rect, tree) end)
       
   336             | SOME (t, i) =>
       
   337                 raise CASE_ERROR ("Not a datatype constructor: " ^ Syntax.string_of_term ctxt t, i))
       
   338           end
       
   339       | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
       
   340   in mk end;
       
   341 
       
   342 
       
   343 (* replace occurrences of dummy_pattern by distinct variables *)
       
   344 fun replace_dummies (Const (@{const_name dummy_pattern}, T)) used =
       
   345       let val (x, used') = Name.variant "x" used
       
   346       in (Free (x, T), used') end
       
   347   | replace_dummies (t $ u) used =
       
   348       let
       
   349         val (t', used') = replace_dummies t used;
       
   350         val (u', used'') = replace_dummies u used';
       
   351       in (t' $ u', used'') end
       
   352   | replace_dummies t used = (t, used);
       
   353 
       
   354 (*Repeated variable occurrences in a pattern are not allowed.*)
       
   355 fun no_repeat_vars ctxt pat = fold_aterms
       
   356   (fn x as Free (s, _) =>
       
   357       (fn xs =>
       
   358         if member op aconv xs x then
       
   359           case_error (quote s ^ " occurs repeatedly in the pattern " ^
       
   360             quote (Syntax.string_of_term ctxt pat))
       
   361         else x :: xs)
       
   362     | _ => I) pat [];
       
   363 
       
   364 fun make_case ctxt config used x clauses =
       
   365   let
       
   366     fun string_of_clause (pat, rhs) =
       
   367       Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);
       
   368     val _ = map (no_repeat_vars ctxt o fst) clauses;
       
   369     val (rows, used') = used |>
       
   370       fold (fn (pat, rhs) =>
       
   371         Term.declare_term_frees pat #> Term.declare_term_frees rhs) clauses |>
       
   372       fold_map (fn (i, (pat, rhs)) => fn used =>
       
   373         let val (pat', used') = replace_dummies pat used
       
   374         in ((([], [pat']), (rhs, i)), used') end)
       
   375           (map_index I clauses);
       
   376     val rangeT =
       
   377       (case distinct (op =) (map (fastype_of o snd) clauses) of
       
   378         [] => case_error "no clauses given"
       
   379       | [T] => T
       
   380       | _ => case_error "all cases must have the same result type");
       
   381     val used' = fold add_row_used rows used;
       
   382     val (tags, case_tm) =
       
   383       mk_case ctxt used' rangeT [x] rows
       
   384         handle CASE_ERROR (msg, i) =>
       
   385           case_error
       
   386             (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
       
   387     val _ =
       
   388       (case subtract (op =) tags (map (snd o snd) rows) of
       
   389         [] => ()
       
   390       | is =>
       
   391           (case config of Error => case_error | Warning => warning | Quiet => fn _ => ())
       
   392             ("The following clauses are redundant (covered by preceding clauses):\n" ^
       
   393               cat_lines (map (string_of_clause o nth clauses) is)));
       
   394   in
       
   395     case_tm
       
   396   end;
       
   397 
       
   398 
       
   399 (* term check *)
       
   400 
       
   401 fun decode_clause (Const (@{const_name case_abs}, _) $ Abs (s, T, t)) xs used =
       
   402       let val (s', used') = Name.variant s used
       
   403       in decode_clause t (Free (s', T) :: xs) used' end
       
   404   | decode_clause (Const (@{const_name case_elem}, _) $ t $ u) xs _ =
       
   405       (subst_bounds (xs, t), subst_bounds (xs, u))
       
   406   | decode_clause _ _ _ = case_error "decode_clause";
       
   407 
       
   408 fun decode_cases (Const (@{const_name case_nil}, _)) = []
       
   409   | decode_cases (Const (@{const_name case_cons}, _) $ t $ u) =
       
   410       decode_clause t [] (Term.declare_term_frees t Name.context) ::
       
   411       decode_cases u
       
   412   | decode_cases _ = case_error "decode_cases";
       
   413 
       
   414 fun check_case ctxt =
       
   415   let
       
   416     fun decode_case ((t as Const (@{const_name case_cons}, _) $ _ $ _) $ u) =
       
   417         make_case ctxt Error Name.context (decode_case u) (decode_cases t)
       
   418     | decode_case (t $ u) = decode_case t $ decode_case u
       
   419     | decode_case (Abs (x, T, u)) =
       
   420         let val (x', u') = Term.dest_abs (x, T, u);
       
   421         in Term.absfree (x', T) (decode_case u') end
       
   422     | decode_case t = t;
       
   423   in
       
   424     map decode_case
       
   425   end;
       
   426 
       
   427 val term_check_setup =
       
   428   Context.theory_map (Syntax_Phases.term_check 1 "case" check_case);
       
   429 
       
   430 
       
   431 (* Pretty printing of nested case expressions *)
       
   432 
       
   433 (* destruct one level of pattern matching *)
       
   434 
       
   435 fun dest_case ctxt d used t =
       
   436   (case apfst name_of (strip_comb t) of
       
   437     (SOME cname, ts as _ :: _) =>
       
   438       let
       
   439         val (fs, x) = split_last ts;
       
   440         fun strip_abs i Us t =
       
   441           let
       
   442             val zs = strip_abs_vars t;
       
   443             val j = length zs;
       
   444             val (xs, ys) =
       
   445               if j < i then (zs @ map (pair "x") (drop j Us), [])
       
   446               else chop i zs;
       
   447             val u = fold_rev Term.abs ys (strip_abs_body t);
       
   448             val xs' = map Free
       
   449               ((fold_map Name.variant (map fst xs)
       
   450                   (Term.declare_term_names u used) |> fst) ~~
       
   451                map snd xs);
       
   452             val (xs1, xs2) = chop j xs'
       
   453           in (xs', list_comb (subst_bounds (rev xs1, u), xs2)) end;
       
   454         fun is_dependent i t =
       
   455           let val k = length (strip_abs_vars t) - i
       
   456           in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
       
   457         fun count_cases (_, _, true) = I
       
   458           | count_cases (c, (_, body), false) = AList.map_default op aconv (body, []) (cons c);
       
   459         val is_undefined = name_of #> equal (SOME @{const_name undefined});
       
   460         fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body);
       
   461         val get_info = lookup_by_case ctxt;
       
   462       in
       
   463         (case get_info cname of
       
   464           SOME (_, constructors) =>
       
   465             if length fs = length constructors then
       
   466               let
       
   467                 val cases = map (fn (Const (s, U), t) =>
       
   468                   let
       
   469                     val Us = binder_types U;
       
   470                     val k = length Us;
       
   471                     val p as (xs, _) = strip_abs k Us t;
       
   472                   in
       
   473                     (Const (s, map fastype_of xs ---> fastype_of x), p, is_dependent k t)
       
   474                   end) (constructors ~~ fs);
       
   475                 val cases' =
       
   476                   sort (int_ord o swap o pairself (length o snd))
       
   477                     (fold_rev count_cases cases []);
       
   478                 val R = fastype_of t;
       
   479                 val dummy =
       
   480                   if d then Term.dummy_pattern R
       
   481                   else Free (Name.variant "x" used |> fst, R);
       
   482               in
       
   483                 SOME (x,
       
   484                   map mk_case
       
   485                     (case find_first (is_undefined o fst) cases' of
       
   486                       SOME (_, cs) =>
       
   487                         if length cs = length constructors then [hd cases]
       
   488                         else filter_out (fn (_, (_, body), _) => is_undefined body) cases
       
   489                     | NONE =>
       
   490                         (case cases' of
       
   491                           [] => cases
       
   492                         | (default, cs) :: _ =>
       
   493                             if length cs = 1 then cases
       
   494                             else if length cs = length constructors then
       
   495                               [hd cases, (dummy, ([], default), false)]
       
   496                             else
       
   497                               filter_out (fn (c, _, _) => member op aconv cs c) cases @
       
   498                                 [(dummy, ([], default), false)])))
       
   499               end
       
   500             else NONE
       
   501         | _ => NONE)
       
   502       end
       
   503   | _ => NONE);
       
   504 
       
   505 
       
   506 (* destruct nested patterns *)
       
   507 
       
   508 fun encode_clause S T (pat, rhs) =
       
   509   fold (fn x as (_, U) => fn t =>
       
   510     Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t)
       
   511       (Term.add_frees pat [])
       
   512       (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs);
       
   513 
       
   514 fun encode_cases S T [] = Const (@{const_name case_nil}, S --> T)
       
   515   | encode_cases S T (p :: ps) =
       
   516       Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $
       
   517         encode_clause S T p $ encode_cases S T ps;
       
   518 
       
   519 fun encode_case (t, ps as (pat, rhs) :: _) =
       
   520       encode_cases (fastype_of pat) (fastype_of rhs) ps $ t
       
   521   | encode_case _ = case_error "encode_case";
       
   522 
       
   523 fun strip_case' ctxt d (pat, rhs) =
       
   524   (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of
       
   525     SOME (exp as Free _, clauses) =>
       
   526       if Term.exists_subterm (curry (op aconv) exp) pat andalso
       
   527         not (exists (fn (_, rhs') =>
       
   528           Term.exists_subterm (curry (op aconv) exp) rhs') clauses)
       
   529       then
       
   530         maps (strip_case' ctxt d) (map (fn (pat', rhs') =>
       
   531           (subst_free [(exp, pat')] pat, rhs')) clauses)
       
   532       else [(pat, rhs)]
       
   533   | _ => [(pat, rhs)]);
       
   534 
       
   535 fun strip_case ctxt d t =
       
   536   (case dest_case ctxt d Name.context t of
       
   537     SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses)
       
   538   | NONE =>
       
   539     (case t of
       
   540       (t $ u) => strip_case ctxt d t $ strip_case ctxt d u
       
   541     | (Abs (x, T, u)) =>
       
   542         let val (x', u') = Term.dest_abs (x, T, u);
       
   543         in Term.absfree (x', T) (strip_case ctxt d u') end
       
   544     | _ => t));
       
   545 
       
   546 
       
   547 (* term uncheck *)
       
   548 
       
   549 val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);
       
   550 
       
   551 fun uncheck_case ctxt ts =
       
   552   if Config.get ctxt show_cases then map (strip_case ctxt true) ts else ts;
       
   553 
       
   554 val term_uncheck_setup =
       
   555   Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case);
       
   556 
       
   557 
       
   558 (* theory setup *)
       
   559 
       
   560 val setup =
       
   561   trfun_setup #>
       
   562   trfun_setup' #>
       
   563   term_check_setup #>
       
   564   term_uncheck_setup;
       
   565 
       
   566 
       
   567 (* outer syntax commands *)
       
   568 
       
   569 fun print_case_translations ctxt =
       
   570   let
       
   571     val cases = Symtab.dest (cases_of ctxt);
       
   572     fun show_case (_, data as (comb, ctrs)) =
       
   573       Pretty.big_list
       
   574         (Pretty.string_of (Pretty.block [Pretty.str (Tname_of_data data), Pretty.str ":"]))
       
   575         [Pretty.block [Pretty.brk 3, Pretty.block
       
   576           [Pretty.str "combinator:", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt comb)]],
       
   577         Pretty.block [Pretty.brk 3, Pretty.block
       
   578           [Pretty.str "constructors:", Pretty.brk 1,
       
   579              Pretty.list "" "" (map (Pretty.quote o Syntax.pretty_term ctxt) ctrs)]]];
       
   580   in
       
   581     Pretty.big_list "Case translations:" (map show_case cases)
       
   582     |> Pretty.writeln
       
   583   end;
       
   584 
       
   585 val _ =
       
   586   Outer_Syntax.improper_command @{command_spec "print_case_translations"}
       
   587     "print registered case combinators and constructors"
       
   588     (Scan.succeed (Toplevel.keep (print_case_translations o Toplevel.context_of)))
       
   589 
       
   590 end;