--- a/src/HOL/HOL.thy Tue Jan 22 13:32:41 2013 +0100
+++ b/src/HOL/HOL.thy Tue Jan 22 14:33:45 2013 +0100
@@ -8,7 +8,8 @@
imports Pure "~~/src/Tools/Code_Generator"
keywords
"try" "solve_direct" "quickcheck"
- "print_coercions" "print_coercion_maps" "print_claset" "print_induct_rules" :: diag and
+ "print_coercions" "print_coercion_maps" "print_claset" "print_induct_rules"
+ "print_case_translations":: diag and
"quickcheck_params" :: thy_decl
begin
--- a/src/HOL/Inductive.thy Tue Jan 22 13:32:41 2013 +0100
+++ b/src/HOL/Inductive.thy Tue Jan 22 14:33:45 2013 +0100
@@ -279,7 +279,8 @@
case_cons :: "('a \<Rightarrow> 'b) \<Rightarrow> ('a \<Rightarrow> 'b) \<Rightarrow> 'a \<Rightarrow> 'b"
case_elem :: "'a \<Rightarrow> 'b \<Rightarrow> 'a \<Rightarrow> 'b"
case_abs :: "('c \<Rightarrow> 'b) \<Rightarrow> 'b"
-ML_file "Tools/Datatype/datatype_case.ML" setup Datatype_Case.setup
+ML_file "Tools/case_translation.ML"
+setup Case_Translation.setup
ML_file "Tools/Datatype/rep_datatype.ML"
ML_file "Tools/Datatype/datatype_codegen.ML" setup Datatype_Codegen.setup
@@ -297,7 +298,7 @@
fun fun_tr ctxt [cs] =
let
val x = Syntax.free (fst (Name.variant "x" (Term.declare_term_frees cs Name.context)));
- val ft = Datatype_Case.case_tr ctxt [x, cs];
+ val ft = Case_Translation.case_tr ctxt [x, cs];
in lambda x ft end
in [(@{syntax_const "_lam_pats_syntax"}, fun_tr)] end
*}
--- a/src/HOL/List.thy Tue Jan 22 13:32:41 2013 +0100
+++ b/src/HOL/List.thy Tue Jan 22 14:33:45 2013 +0100
@@ -407,7 +407,7 @@
Syntax.const @{syntax_const "_case1"} $
Syntax.const @{const_syntax dummy_pattern} $ NilC;
val cs = Syntax.const @{syntax_const "_case2"} $ case1 $ case2;
- in Syntax_Trans.abs_tr [x, Datatype_Case.case_tr ctxt [x, cs]] end;
+ in Syntax_Trans.abs_tr [x, Case_Translation.case_tr ctxt [x, cs]] end;
fun abs_tr ctxt p e opti =
(case Term_Position.strip_positions p of
--- a/src/HOL/Tools/Datatype/datatype_case.ML Tue Jan 22 13:32:41 2013 +0100
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,490 +0,0 @@
-(* Title: HOL/Tools/Datatype/datatype_case.ML
- Author: Konrad Slind, Cambridge University Computer Laboratory
- Author: Stefan Berghofer, TU Muenchen
-
-Datatype package: nested case expressions on datatypes.
-
-TODO:
- * Avoid hard-wiring with datatype package. Instead provide generic
- generic declarations of case splits based on an internal data slot.
-*)
-
-signature DATATYPE_CASE =
-sig
- datatype config = Error | Warning | Quiet
- type info = Datatype_Aux.info
- val case_tr: Proof.context -> term list -> term
- val make_case: Proof.context -> config -> Name.context -> term -> (term * term) list -> term
- val strip_case: Proof.context -> bool -> term -> term
- val show_cases: bool Config.T
- val setup: theory -> theory
-end;
-
-structure Datatype_Case : DATATYPE_CASE =
-struct
-
-datatype config = Error | Warning | Quiet;
-type info = Datatype_Aux.info;
-
-exception CASE_ERROR of string * int;
-
-fun match_type ctxt pat ob =
- Sign.typ_match (Proof_Context.theory_of ctxt) (pat, ob) Vartab.empty;
-
-(* Get information about datatypes *)
-
-fun ty_info ({descr, case_name, index, ...} : info) =
- let
- val (_, (tname, dts, constrs)) = nth descr index;
- val mk_ty = Datatype_Aux.typ_of_dtyp descr;
- val T = Type (tname, map mk_ty dts);
- in
- {case_name = case_name,
- constructors = map (fn (cname, dts') =>
- Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs}
- end;
-
-
-(*Each pattern carries with it a tag i, which denotes the clause it
-came from. i = ~1 indicates that the clause was added by pattern
-completion.*)
-
-fun add_row_used ((prfx, pats), (tm, tag)) =
- fold Term.declare_term_frees (tm :: pats @ map Free prfx);
-
-(* try to preserve names given by user *)
-fun default_name "" (Free (name', _)) = name'
- | default_name name _ = name;
-
-
-(*Produce an instance of a constructor, plus fresh variables for its arguments.*)
-fun fresh_constr ctxt colty used c =
- let
- val (_, T) = dest_Const c;
- val Ts = binder_types T;
- val (names, _) = fold_map Name.variant
- (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
- val ty = body_type T;
- val ty_theta = match_type ctxt ty colty
- handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
- val c' = Envir.subst_term_types ty_theta c;
- val gvars = map (Envir.subst_term_types ty_theta o Free) (names ~~ Ts);
- in (c', gvars) end;
-
-(*Go through a list of rows and pick out the ones beginning with a
- pattern with constructor = name.*)
-fun mk_group (name, T) rows =
- let val k = length (binder_types T) in
- fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) =>
- fn ((in_group, not_in_group), names) =>
- (case strip_comb p of
- (Const (name', _), args) =>
- if name = name' then
- if length args = k then
- ((((prfx, args @ ps), rhs) :: in_group, not_in_group),
- map2 default_name names args)
- else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
- else ((in_group, row :: not_in_group), names)
- | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
- rows (([], []), replicate k "") |>> pairself rev
- end;
-
-
-(* Partitioning *)
-
-fun partition _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
- | partition ctxt used constructors colty res_ty
- (rows as (((prfx, _ :: ps), _) :: _)) =
- let
- fun part [] [] = []
- | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
- | part (c :: cs) rows =
- let
- val ((in_group, not_in_group), names) = mk_group (dest_Const c) rows;
- val used' = fold add_row_used in_group used;
- val (c', gvars) = fresh_constr ctxt colty used' c;
- val in_group' =
- if null in_group (* Constructor not given *)
- then
- let
- val Ts = map fastype_of ps;
- val (xs, _) =
- fold_map Name.variant
- (replicate (length ps) "x")
- (fold Term.declare_term_frees gvars used');
- in
- [((prfx, gvars @ map Free (xs ~~ Ts)),
- (Const (@{const_name undefined}, res_ty), ~1))]
- end
- else in_group;
- in
- {constructor = c',
- new_formals = gvars,
- names = names,
- group = in_group'} :: part cs not_in_group
- end;
- in part constructors rows end;
-
-fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats)
- | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
-
-
-(* Translation of pattern terms into nested case expressions. *)
-
-fun mk_case ctxt used range_ty =
- let
- val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
-
- fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1)
- | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
- if is_Free p then
- let
- val used' = add_row_used row used;
- fun expnd c =
- let val capp = list_comb (fresh_constr ctxt ty used' c)
- in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
- in map expnd constructors end
- else [row];
-
- val (name, _) = Name.variant "a" used;
-
- fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
- | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
- | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
- | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
- let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
- (case Option.map (apfst head_of)
- (find_first (not o is_Free o fst) col0) of
- NONE =>
- let
- val rows' = map (fn ((v, _), row) => row ||>
- apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
- in mk us rows' end
- | SOME (Const (cname, cT), i) =>
- (case Option.map ty_info (get_info (cname, cT)) of
- NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i)
- | SOME {case_name, constructors} =>
- let
- val pty = body_type cT;
- val used' = fold Term.declare_term_frees us used;
- val nrows = maps (expand constructors used' pty) rows;
- val subproblems =
- partition ctxt used' constructors pty range_ty nrows;
- val (pat_rect, dtrees) =
- split_list (map (fn {new_formals, group, ...} =>
- mk (new_formals @ us) group) subproblems);
- val case_functions =
- map2 (fn {new_formals, names, ...} =>
- fold_rev (fn (x as Free (_, T), s) => fn t =>
- Abs (if s = "" then name else s, T, abstract_over (x, t)))
- (new_formals ~~ names))
- subproblems dtrees;
- val types = map fastype_of (case_functions @ [u]);
- val case_const = Const (case_name, types ---> range_ty);
- val tree = list_comb (case_const, case_functions @ [u]);
- in (flat pat_rect, tree) end)
- | SOME (t, i) =>
- raise CASE_ERROR ("Not a datatype constructor: " ^ Syntax.string_of_term ctxt t, i))
- end
- | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
- in mk end;
-
-
-(* replace occurrences of dummy_pattern by distinct variables *)
-fun replace_dummies (Const (@{const_name dummy_pattern}, T)) used =
- let val (x, used') = Name.variant "x" used
- in (Free (x, T), used') end
- | replace_dummies (t $ u) used =
- let
- val (t', used') = replace_dummies t used;
- val (u', used'') = replace_dummies u used';
- in (t' $ u', used'') end
- | replace_dummies t used = (t, used);
-
-fun case_error s = error ("Error in case expression:\n" ^ s);
-
-(*Repeated variable occurrences in a pattern are not allowed.*)
-fun no_repeat_vars ctxt pat = fold_aterms
- (fn x as Free (s, _) =>
- (fn xs =>
- if member op aconv xs x then
- case_error (quote s ^ " occurs repeatedly in the pattern " ^
- quote (Syntax.string_of_term ctxt pat))
- else x :: xs)
- | _ => I) pat [];
-
-fun make_case ctxt config used x clauses =
- let
- fun string_of_clause (pat, rhs) =
- Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);
- val _ = map (no_repeat_vars ctxt o fst) clauses;
- val (rows, used') = used |>
- fold (fn (pat, rhs) =>
- Term.declare_term_frees pat #> Term.declare_term_frees rhs) clauses |>
- fold_map (fn (i, (pat, rhs)) => fn used =>
- let val (pat', used') = replace_dummies pat used
- in ((([], [pat']), (rhs, i)), used') end)
- (map_index I clauses);
- val rangeT =
- (case distinct (op =) (map (fastype_of o snd) clauses) of
- [] => case_error "no clauses given"
- | [T] => T
- | _ => case_error "all cases must have the same result type");
- val used' = fold add_row_used rows used;
- val (tags, case_tm) =
- mk_case ctxt used' rangeT [x] rows
- handle CASE_ERROR (msg, i) =>
- case_error
- (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
- val _ =
- (case subtract (op =) tags (map (snd o snd) rows) of
- [] => ()
- | is =>
- (case config of Error => case_error | Warning => warning | Quiet => fn _ => ())
- ("The following clauses are redundant (covered by preceding clauses):\n" ^
- cat_lines (map (string_of_clause o nth clauses) is)));
- in
- case_tm
- end;
-
-
-(* term check *)
-
-fun decode_clause (Const (@{const_name case_abs}, _) $ Abs (s, T, t)) xs used =
- let val (s', used') = Name.variant s used
- in decode_clause t (Free (s', T) :: xs) used' end
- | decode_clause (Const (@{const_name case_elem}, _) $ t $ u) xs _ =
- (subst_bounds (xs, t), subst_bounds (xs, u))
- | decode_clause _ _ _ = case_error "decode_clause";
-
-fun decode_cases (Const (@{const_name case_nil}, _)) = []
- | decode_cases (Const (@{const_name case_cons}, _) $ t $ u) =
- decode_clause t [] (Term.declare_term_frees t Name.context) ::
- decode_cases u
- | decode_cases _ = case_error "decode_cases";
-
-fun check_case ctxt =
- let
- fun decode_case ((t as Const (@{const_name case_cons}, _) $ _ $ _) $ u) =
- make_case ctxt Error Name.context (decode_case u) (decode_cases t)
- | decode_case (t $ u) = decode_case t $ decode_case u
- | decode_case (Abs (x, T, u)) =
- let val (x', u') = Term.dest_abs (x, T, u);
- in Term.absfree (x', T) (decode_case u') end
- | decode_case t = t;
- in
- map decode_case
- end;
-
-val term_check_setup =
- Context.theory_map (Syntax_Phases.term_check 1 "case" check_case);
-
-
-(* parse translation *)
-
-fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
-
-fun case_tr ctxt [t, u] =
- let
- val thy = Proof_Context.theory_of ctxt;
-
- fun is_const s =
- Sign.declared_const thy (Proof_Context.intern_const ctxt s);
-
- fun abs p tTs t = Syntax.const @{const_syntax case_abs} $
- fold constrain_Abs tTs (absfree p t);
-
- fun abs_pat (Const ("_constrain", _) $ t $ tT) tTs = abs_pat t (tT :: tTs)
- | abs_pat (Free (p as (x, _))) tTs =
- if is_const x then I else abs p tTs
- | abs_pat (t $ u) _ = abs_pat u [] #> abs_pat t []
- | abs_pat _ _ = I;
-
- fun dest_case1 (Const (@{syntax_const "_case1"}, _) $ l $ r) =
- abs_pat l []
- (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l $ r)
- | dest_case1 _ = case_error "dest_case1";
-
- fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
- | dest_case2 t = [t];
- in
- fold_rev
- (fn t => fn u =>
- Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
- (dest_case2 u)
- (Syntax.const @{const_syntax case_nil}) $ t
- end
- | case_tr _ _ = case_error "case_tr";
-
-val trfun_setup =
- Sign.add_advanced_trfuns ([],
- [(@{syntax_const "_case_syntax"}, case_tr)],
- [], []);
-
-
-(* Pretty printing of nested case expressions *)
-
-val name_of = try (dest_Const #> fst);
-
-(* destruct one level of pattern matching *)
-
-fun dest_case ctxt d used t =
- (case apfst name_of (strip_comb t) of
- (SOME cname, ts as _ :: _) =>
- let
- val (fs, x) = split_last ts;
- fun strip_abs i Us t =
- let
- val zs = strip_abs_vars t;
- val j = length zs;
- val (xs, ys) =
- if j < i then (zs @ map (pair "x") (drop j Us), [])
- else chop i zs;
- val u = fold_rev Term.abs ys (strip_abs_body t);
- val xs' = map Free
- ((fold_map Name.variant (map fst xs)
- (Term.declare_term_names u used) |> fst) ~~
- map snd xs);
- val (xs1, xs2) = chop j xs'
- in (xs', list_comb (subst_bounds (rev xs1, u), xs2)) end;
- fun is_dependent i t =
- let val k = length (strip_abs_vars t) - i
- in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
- fun count_cases (_, _, true) = I
- | count_cases (c, (_, body), false) = AList.map_default op aconv (body, []) (cons c);
- val is_undefined = name_of #> equal (SOME @{const_name undefined});
- fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body);
- val get_info = Datatype_Data.info_of_case (Proof_Context.theory_of ctxt);
- in
- (case Option.map ty_info (get_info cname) of
- SOME {constructors, ...} =>
- if length fs = length constructors then
- let
- val cases = map (fn (Const (s, U), t) =>
- let
- val Us = binder_types U;
- val k = length Us;
- val p as (xs, _) = strip_abs k Us t;
- in
- (Const (s, map fastype_of xs ---> fastype_of x), p, is_dependent k t)
- end) (constructors ~~ fs);
- val cases' =
- sort (int_ord o swap o pairself (length o snd))
- (fold_rev count_cases cases []);
- val R = fastype_of t;
- val dummy =
- if d then Term.dummy_pattern R
- else Free (Name.variant "x" used |> fst, R);
- in
- SOME (x,
- map mk_case
- (case find_first (is_undefined o fst) cases' of
- SOME (_, cs) =>
- if length cs = length constructors then [hd cases]
- else filter_out (fn (_, (_, body), _) => is_undefined body) cases
- | NONE =>
- (case cases' of
- [] => cases
- | (default, cs) :: _ =>
- if length cs = 1 then cases
- else if length cs = length constructors then
- [hd cases, (dummy, ([], default), false)]
- else
- filter_out (fn (c, _, _) => member op aconv cs c) cases @
- [(dummy, ([], default), false)])))
- end
- else NONE
- | _ => NONE)
- end
- | _ => NONE);
-
-
-(* destruct nested patterns *)
-
-fun encode_clause S T (pat, rhs) =
- fold (fn x as (_, U) => fn t =>
- Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t)
- (Term.add_frees pat [])
- (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs);
-
-fun encode_cases S T [] = Const (@{const_name case_nil}, S --> T)
- | encode_cases S T (p :: ps) =
- Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $
- encode_clause S T p $ encode_cases S T ps;
-
-fun encode_case (t, ps as (pat, rhs) :: _) =
- encode_cases (fastype_of pat) (fastype_of rhs) ps $ t
- | encode_case _ = case_error "encode_case";
-
-fun strip_case' ctxt d (pat, rhs) =
- (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of
- SOME (exp as Free _, clauses) =>
- if Term.exists_subterm (curry (op aconv) exp) pat andalso
- not (exists (fn (_, rhs') =>
- Term.exists_subterm (curry (op aconv) exp) rhs') clauses)
- then
- maps (strip_case' ctxt d) (map (fn (pat', rhs') =>
- (subst_free [(exp, pat')] pat, rhs')) clauses)
- else [(pat, rhs)]
- | _ => [(pat, rhs)]);
-
-fun strip_case ctxt d t =
- (case dest_case ctxt d Name.context t of
- SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses)
- | NONE =>
- (case t of
- (t $ u) => strip_case ctxt d t $ strip_case ctxt d u
- | (Abs (x, T, u)) =>
- let val (x', u') = Term.dest_abs (x, T, u);
- in Term.absfree (x', T) (strip_case ctxt d u') end
- | _ => t));
-
-
-(* term uncheck *)
-
-val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);
-
-fun uncheck_case ctxt ts =
- if Config.get ctxt show_cases then map (strip_case ctxt true) ts else ts;
-
-val term_uncheck_setup =
- Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case);
-
-
-(* print translation *)
-
-fun case_tr' [t, u, x] =
- let
- fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
- let val (s', used') = Name.variant s used
- in mk_clause t ((s', T) :: xs) used' end
- | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
- Syntax.const @{syntax_const "_case1"} $
- subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
- subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
-
- fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
- | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
- mk_clauses' t u
- and mk_clauses' t u =
- mk_clause t [] (Term.declare_term_frees t Name.context) ::
- mk_clauses u
- in
- Syntax.const @{syntax_const "_case_syntax"} $ x $
- foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
- (mk_clauses' t u)
- end;
-
-val trfun_setup' = Sign.add_trfuns
- ([], [], [(@{const_syntax "case_cons"}, case_tr')], []);
-
-
-(* theory setup *)
-
-val setup =
- trfun_setup #>
- trfun_setup' #>
- term_check_setup #>
- term_uncheck_setup;
-
-end;
--- a/src/HOL/Tools/Datatype/rep_datatype.ML Tue Jan 22 13:32:41 2013 +0100
+++ b/src/HOL/Tools/Datatype/rep_datatype.ML Tue Jan 22 14:33:45 2013 +0100
@@ -518,6 +518,12 @@
val unnamed_rules = map (fn induct =>
((Binding.empty, [Rule_Cases.inner_rule, Induct.induct_type ""]), [([induct], [])]))
(drop (length dt_names) inducts);
+
+ val ctxt = Proof_Context.init_global thy9;
+ val case_combs = map (Proof_Context.read_const ctxt false dummyT) case_names;
+ val constrss = map (fn (dtname, {descr, index, ...}) =>
+ map (Proof_Context.read_const ctxt false dummyT o fst)
+ (#3 (the (AList.lookup op = descr index)))) dt_infos
in
thy9
|> Global_Theory.note_thmss ""
@@ -535,6 +541,7 @@
named_rules @ unnamed_rules)
|> snd
|> Datatype_Data.register dt_infos
+ |> Context.theory_map (fold2 Case_Translation.register case_combs constrss)
|> Datatype_Data.interpretation_data (config, dt_names)
|> pair dt_names
end;
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Tue Jan 22 13:32:41 2013 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Tue Jan 22 14:33:45 2013 +0100
@@ -640,7 +640,7 @@
val v = Free (name, T);
val v' = Free (name', T);
in
- lambda v (Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context v
+ lambda v (Case_Translation.make_case ctxt Case_Translation.Quiet Name.context v
[(HOLogic.mk_tuple out_ts,
if null eqs'' then success_t
else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
@@ -916,7 +916,7 @@
in
(pattern, compilation)
end
- val switch = Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context inp_var
+ val switch = Case_Translation.make_case ctxt Case_Translation.Quiet Name.context inp_var
((map compile_single_case switched_clauses) @
[(xt, mk_empty compfuns (HOLogic.mk_tupleT outTs))])
in
--- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML Tue Jan 22 13:32:41 2013 +0100
+++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML Tue Jan 22 14:33:45 2013 +0100
@@ -293,7 +293,7 @@
val cont_t = mk_smart_test_term' concl bound_vars' (new_assms @ assms) genuine
in
mk_test (vars_of lhs,
- Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context lhs
+ Case_Translation.make_case ctxt Case_Translation.Quiet Name.context lhs
[(list_comb (constr, vars), cont_t), (dummy_var, none_t)])
end
else c (assm, assms)
--- a/src/HOL/Tools/Quickcheck/narrowing_generators.ML Tue Jan 22 13:32:41 2013 +0100
+++ b/src/HOL/Tools/Quickcheck/narrowing_generators.ML Tue Jan 22 14:33:45 2013 +0100
@@ -427,7 +427,7 @@
end
fun mk_case_term ctxt p ((@{const_name Ex}, (x, T)) :: qs') (Existential_Counterexample cs) =
- Datatype_Case.make_case ctxt Datatype_Case.Quiet Name.context (Free (x, T)) (map (fn (t, c) =>
+ Case_Translation.make_case ctxt Case_Translation.Quiet Name.context (Free (x, T)) (map (fn (t, c) =>
(t, mk_case_term ctxt (p - 1) qs' c)) cs)
| mk_case_term ctxt p ((@{const_name All}, _) :: qs') (Universal_Counterexample (t, c)) =
if p = 0 then t else mk_case_term ctxt (p - 1) qs' c
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/case_translation.ML Tue Jan 22 14:33:45 2013 +0100
@@ -0,0 +1,590 @@
+(* Title: Tools/case_translation.ML
+ Author: Konrad Slind, Cambridge University Computer Laboratory
+ Author: Stefan Berghofer, TU Muenchen
+ Author: Dmitriy Traytel, TU Muenchen
+
+Nested case expressions via a generic data slot for case combinators and constructors.
+*)
+
+signature CASE_TRANSLATION =
+sig
+ datatype config = Error | Warning | Quiet
+ val case_tr: Proof.context -> term list -> term
+ val lookup_by_constr: Proof.context -> string * typ -> (term * term list) option
+ val lookup_by_constr_permissive: Proof.context -> string * typ -> (term * term list) option
+ val lookup_by_case: Proof.context -> string -> (term * term list) option
+ val make_case: Proof.context -> config -> Name.context -> term -> (term * term) list -> term
+ val print_case_translations: Proof.context -> unit
+ val strip_case: Proof.context -> bool -> term -> term
+ val show_cases: bool Config.T
+ val setup: theory -> theory
+ val register: term -> term list -> Context.generic -> Context.generic
+end;
+
+structure Case_Translation: CASE_TRANSLATION =
+struct
+
+(** data management **)
+
+datatype data = Data of
+ {constrs: (string * (term * term list)) list Symtab.table,
+ cases: (term * term list) Symtab.table};
+
+fun make_data (constrs, cases) = Data {constrs = constrs, cases = cases};
+
+structure Data = Generic_Data
+(
+ type T = data;
+ val empty = make_data (Symtab.empty, Symtab.empty);
+ val extend = I;
+ fun merge
+ (Data {constrs = constrs1, cases = cases1},
+ Data {constrs = constrs2, cases = cases2}) =
+ make_data
+ (Symtab.join (K (AList.merge (op =) (K true))) (constrs1, constrs2),
+ Symtab.merge (K true) (cases1, cases2));
+);
+
+fun map_data f =
+ Data.map (fn Data {constrs, cases} => make_data (f (constrs, cases)));
+fun map_constrs f = map_data (fn (constrs, cases) => (f constrs, cases));
+fun map_cases f = map_data (fn (constrs, cases) => (constrs, f cases));
+
+val rep_data = (fn Data args => args) o Data.get o Context.Proof;
+
+fun T_of_data (comb, constrs) =
+ fastype_of comb
+ |> funpow (length constrs) range_type
+ |> domain_type;
+
+val Tname_of_data = fst o dest_Type o T_of_data;
+
+val constrs_of = #constrs o rep_data;
+val cases_of = #cases o rep_data;
+
+fun lookup_by_constr ctxt (c, T) =
+ let
+ val tab = Symtab.lookup_list (constrs_of ctxt) c;
+ in
+ (case body_type T of
+ Type (tyco, _) => AList.lookup (op =) tab tyco
+ | _ => NONE)
+ end;
+
+fun lookup_by_constr_permissive ctxt (c, T) =
+ let
+ val tab = Symtab.lookup_list (constrs_of ctxt) c;
+ val hint = (case body_type T of Type (tyco, _) => SOME tyco | _ => NONE);
+ val default = if null tab then NONE else SOME (snd (List.last tab));
+ (*conservative wrt. overloaded constructors*)
+ in
+ (case hint of
+ NONE => default
+ | SOME tyco =>
+ (case AList.lookup (op =) tab tyco of
+ NONE => default (*permissive*)
+ | SOME info => SOME info))
+ end;
+
+val lookup_by_case = Symtab.lookup o cases_of;
+
+
+(** installation **)
+
+fun case_error s = error ("Error in case expression:\n" ^ s);
+
+val name_of = try (dest_Const #> fst);
+
+(* parse translation *)
+
+fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
+
+fun case_tr ctxt [t, u] =
+ let
+ val thy = Proof_Context.theory_of ctxt;
+
+ fun is_const s =
+ Sign.declared_const thy (Proof_Context.intern_const ctxt s);
+
+ fun abs p tTs t = Syntax.const @{const_syntax case_abs} $
+ fold constrain_Abs tTs (absfree p t);
+
+ fun abs_pat (Const ("_constrain", _) $ t $ tT) tTs = abs_pat t (tT :: tTs)
+ | abs_pat (Free (p as (x, _))) tTs =
+ if is_const x then I else abs p tTs
+ | abs_pat (t $ u) _ = abs_pat u [] #> abs_pat t []
+ | abs_pat _ _ = I;
+
+ fun dest_case1 (Const (@{syntax_const "_case1"}, _) $ l $ r) =
+ abs_pat l []
+ (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l $ r)
+ | dest_case1 _ = case_error "dest_case1";
+
+ fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
+ | dest_case2 t = [t];
+ in
+ fold_rev
+ (fn t => fn u =>
+ Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
+ (dest_case2 u)
+ (Syntax.const @{const_syntax case_nil}) $ t
+ end
+ | case_tr _ _ = case_error "case_tr";
+
+val trfun_setup =
+ Sign.add_advanced_trfuns ([],
+ [(@{syntax_const "_case_syntax"}, case_tr)],
+ [], []);
+
+
+(* print translation *)
+
+fun case_tr' [t, u, x] =
+ let
+ fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
+ let val (s', used') = Name.variant s used
+ in mk_clause t ((s', T) :: xs) used' end
+ | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
+ Syntax.const @{syntax_const "_case1"} $
+ subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
+ subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
+
+ fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
+ | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
+ mk_clauses' t u
+ and mk_clauses' t u =
+ mk_clause t [] (Term.declare_term_frees t Name.context) ::
+ mk_clauses u
+ in
+ Syntax.const @{syntax_const "_case_syntax"} $ x $
+ foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
+ (mk_clauses' t u)
+ end;
+
+val trfun_setup' = Sign.add_trfuns
+ ([], [], [(@{const_syntax "case_cons"}, case_tr')], []);
+
+
+(* declarations *)
+
+fun register raw_case_comb raw_constrs context =
+ let
+ val ctxt = Context.proof_of context;
+ val case_comb = singleton (Variable.polymorphic ctxt) raw_case_comb;
+ val constrs = Variable.polymorphic ctxt raw_constrs;
+ val case_key = case_comb |> dest_Const |> fst;
+ val constr_keys = map (fst o dest_Const) constrs;
+ val data = (case_comb, constrs);
+ val Tname = Tname_of_data data;
+ val update_constrs = fold (fn key => Symtab.cons_list (key, (Tname, data))) constr_keys;
+ val update_cases = Symtab.update (case_key, data);
+ in
+ context
+ |> map_constrs update_constrs
+ |> map_cases update_cases
+ end;
+
+
+(* (Un)check phases *)
+
+datatype config = Error | Warning | Quiet;
+
+exception CASE_ERROR of string * int;
+
+fun match_type ctxt pat ob =
+ Sign.typ_match (Proof_Context.theory_of ctxt) (pat, ob) Vartab.empty;
+
+
+(*Each pattern carries with it a tag i, which denotes the clause it
+came from. i = ~1 indicates that the clause was added by pattern
+completion.*)
+
+fun add_row_used ((prfx, pats), (tm, tag)) =
+ fold Term.declare_term_frees (tm :: pats @ map Free prfx);
+
+(* try to preserve names given by user *)
+fun default_name "" (Free (name', _)) = name'
+ | default_name name _ = name;
+
+
+(*Produce an instance of a constructor, plus fresh variables for its arguments.*)
+fun fresh_constr ctxt colty used c =
+ let
+ val (_, T) = dest_Const c;
+ val Ts = binder_types T;
+ val (names, _) = fold_map Name.variant
+ (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
+ val ty = body_type T;
+ val ty_theta = match_type ctxt ty colty
+ handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
+ val c' = Envir.subst_term_types ty_theta c;
+ val gvars = map (Envir.subst_term_types ty_theta o Free) (names ~~ Ts);
+ in (c', gvars) end;
+
+(*Go through a list of rows and pick out the ones beginning with a
+ pattern with constructor = name.*)
+fun mk_group (name, T) rows =
+ let val k = length (binder_types T) in
+ fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) =>
+ fn ((in_group, not_in_group), names) =>
+ (case strip_comb p of
+ (Const (name', _), args) =>
+ if name = name' then
+ if length args = k then
+ ((((prfx, args @ ps), rhs) :: in_group, not_in_group),
+ map2 default_name names args)
+ else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
+ else ((in_group, row :: not_in_group), names)
+ | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
+ rows (([], []), replicate k "") |>> pairself rev
+ end;
+
+
+(* Partitioning *)
+
+fun partition _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
+ | partition ctxt used constructors colty res_ty
+ (rows as (((prfx, _ :: ps), _) :: _)) =
+ let
+ fun part [] [] = []
+ | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
+ | part (c :: cs) rows =
+ let
+ val ((in_group, not_in_group), names) = mk_group (dest_Const c) rows;
+ val used' = fold add_row_used in_group used;
+ val (c', gvars) = fresh_constr ctxt colty used' c;
+ val in_group' =
+ if null in_group (* Constructor not given *)
+ then
+ let
+ val Ts = map fastype_of ps;
+ val (xs, _) =
+ fold_map Name.variant
+ (replicate (length ps) "x")
+ (fold Term.declare_term_frees gvars used');
+ in
+ [((prfx, gvars @ map Free (xs ~~ Ts)),
+ (Const (@{const_name undefined}, res_ty), ~1))]
+ end
+ else in_group;
+ in
+ {constructor = c',
+ new_formals = gvars,
+ names = names,
+ group = in_group'} :: part cs not_in_group
+ end;
+ in part constructors rows end;
+
+fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats)
+ | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
+
+
+(* Translation of pattern terms into nested case expressions. *)
+
+fun mk_case ctxt used range_ty =
+ let
+ val get_info = lookup_by_constr_permissive ctxt;
+
+ fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1)
+ | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
+ if is_Free p then
+ let
+ val used' = add_row_used row used;
+ fun expnd c =
+ let val capp = list_comb (fresh_constr ctxt ty used' c)
+ in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
+ in map expnd constructors end
+ else [row];
+
+ val (name, _) = Name.variant "a" used;
+
+ fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
+ | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
+ | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
+ | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
+ let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
+ (case Option.map (apfst head_of)
+ (find_first (not o is_Free o fst) col0) of
+ NONE =>
+ let
+ val rows' = map (fn ((v, _), row) => row ||>
+ apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
+ in mk us rows' end
+ | SOME (Const (cname, cT), i) =>
+ (case get_info (cname, cT) of
+ NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i)
+ | SOME (case_comb, constructors) =>
+ let
+ val pty = body_type cT;
+ val used' = fold Term.declare_term_frees us used;
+ val nrows = maps (expand constructors used' pty) rows;
+ val subproblems =
+ partition ctxt used' constructors pty range_ty nrows;
+ val (pat_rect, dtrees) =
+ split_list (map (fn {new_formals, group, ...} =>
+ mk (new_formals @ us) group) subproblems);
+ val case_functions =
+ map2 (fn {new_formals, names, ...} =>
+ fold_rev (fn (x as Free (_, T), s) => fn t =>
+ Abs (if s = "" then name else s, T, abstract_over (x, t)))
+ (new_formals ~~ names))
+ subproblems dtrees;
+ val types = map fastype_of (case_functions @ [u]);
+ val case_const = Const (name_of case_comb |> the, types ---> range_ty);
+ val tree = list_comb (case_const, case_functions @ [u]);
+ in (flat pat_rect, tree) end)
+ | SOME (t, i) =>
+ raise CASE_ERROR ("Not a datatype constructor: " ^ Syntax.string_of_term ctxt t, i))
+ end
+ | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
+ in mk end;
+
+
+(* replace occurrences of dummy_pattern by distinct variables *)
+fun replace_dummies (Const (@{const_name dummy_pattern}, T)) used =
+ let val (x, used') = Name.variant "x" used
+ in (Free (x, T), used') end
+ | replace_dummies (t $ u) used =
+ let
+ val (t', used') = replace_dummies t used;
+ val (u', used'') = replace_dummies u used';
+ in (t' $ u', used'') end
+ | replace_dummies t used = (t, used);
+
+(*Repeated variable occurrences in a pattern are not allowed.*)
+fun no_repeat_vars ctxt pat = fold_aterms
+ (fn x as Free (s, _) =>
+ (fn xs =>
+ if member op aconv xs x then
+ case_error (quote s ^ " occurs repeatedly in the pattern " ^
+ quote (Syntax.string_of_term ctxt pat))
+ else x :: xs)
+ | _ => I) pat [];
+
+fun make_case ctxt config used x clauses =
+ let
+ fun string_of_clause (pat, rhs) =
+ Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);
+ val _ = map (no_repeat_vars ctxt o fst) clauses;
+ val (rows, used') = used |>
+ fold (fn (pat, rhs) =>
+ Term.declare_term_frees pat #> Term.declare_term_frees rhs) clauses |>
+ fold_map (fn (i, (pat, rhs)) => fn used =>
+ let val (pat', used') = replace_dummies pat used
+ in ((([], [pat']), (rhs, i)), used') end)
+ (map_index I clauses);
+ val rangeT =
+ (case distinct (op =) (map (fastype_of o snd) clauses) of
+ [] => case_error "no clauses given"
+ | [T] => T
+ | _ => case_error "all cases must have the same result type");
+ val used' = fold add_row_used rows used;
+ val (tags, case_tm) =
+ mk_case ctxt used' rangeT [x] rows
+ handle CASE_ERROR (msg, i) =>
+ case_error
+ (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
+ val _ =
+ (case subtract (op =) tags (map (snd o snd) rows) of
+ [] => ()
+ | is =>
+ (case config of Error => case_error | Warning => warning | Quiet => fn _ => ())
+ ("The following clauses are redundant (covered by preceding clauses):\n" ^
+ cat_lines (map (string_of_clause o nth clauses) is)));
+ in
+ case_tm
+ end;
+
+
+(* term check *)
+
+fun decode_clause (Const (@{const_name case_abs}, _) $ Abs (s, T, t)) xs used =
+ let val (s', used') = Name.variant s used
+ in decode_clause t (Free (s', T) :: xs) used' end
+ | decode_clause (Const (@{const_name case_elem}, _) $ t $ u) xs _ =
+ (subst_bounds (xs, t), subst_bounds (xs, u))
+ | decode_clause _ _ _ = case_error "decode_clause";
+
+fun decode_cases (Const (@{const_name case_nil}, _)) = []
+ | decode_cases (Const (@{const_name case_cons}, _) $ t $ u) =
+ decode_clause t [] (Term.declare_term_frees t Name.context) ::
+ decode_cases u
+ | decode_cases _ = case_error "decode_cases";
+
+fun check_case ctxt =
+ let
+ fun decode_case ((t as Const (@{const_name case_cons}, _) $ _ $ _) $ u) =
+ make_case ctxt Error Name.context (decode_case u) (decode_cases t)
+ | decode_case (t $ u) = decode_case t $ decode_case u
+ | decode_case (Abs (x, T, u)) =
+ let val (x', u') = Term.dest_abs (x, T, u);
+ in Term.absfree (x', T) (decode_case u') end
+ | decode_case t = t;
+ in
+ map decode_case
+ end;
+
+val term_check_setup =
+ Context.theory_map (Syntax_Phases.term_check 1 "case" check_case);
+
+
+(* Pretty printing of nested case expressions *)
+
+(* destruct one level of pattern matching *)
+
+fun dest_case ctxt d used t =
+ (case apfst name_of (strip_comb t) of
+ (SOME cname, ts as _ :: _) =>
+ let
+ val (fs, x) = split_last ts;
+ fun strip_abs i Us t =
+ let
+ val zs = strip_abs_vars t;
+ val j = length zs;
+ val (xs, ys) =
+ if j < i then (zs @ map (pair "x") (drop j Us), [])
+ else chop i zs;
+ val u = fold_rev Term.abs ys (strip_abs_body t);
+ val xs' = map Free
+ ((fold_map Name.variant (map fst xs)
+ (Term.declare_term_names u used) |> fst) ~~
+ map snd xs);
+ val (xs1, xs2) = chop j xs'
+ in (xs', list_comb (subst_bounds (rev xs1, u), xs2)) end;
+ fun is_dependent i t =
+ let val k = length (strip_abs_vars t) - i
+ in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
+ fun count_cases (_, _, true) = I
+ | count_cases (c, (_, body), false) = AList.map_default op aconv (body, []) (cons c);
+ val is_undefined = name_of #> equal (SOME @{const_name undefined});
+ fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body);
+ val get_info = lookup_by_case ctxt;
+ in
+ (case get_info cname of
+ SOME (_, constructors) =>
+ if length fs = length constructors then
+ let
+ val cases = map (fn (Const (s, U), t) =>
+ let
+ val Us = binder_types U;
+ val k = length Us;
+ val p as (xs, _) = strip_abs k Us t;
+ in
+ (Const (s, map fastype_of xs ---> fastype_of x), p, is_dependent k t)
+ end) (constructors ~~ fs);
+ val cases' =
+ sort (int_ord o swap o pairself (length o snd))
+ (fold_rev count_cases cases []);
+ val R = fastype_of t;
+ val dummy =
+ if d then Term.dummy_pattern R
+ else Free (Name.variant "x" used |> fst, R);
+ in
+ SOME (x,
+ map mk_case
+ (case find_first (is_undefined o fst) cases' of
+ SOME (_, cs) =>
+ if length cs = length constructors then [hd cases]
+ else filter_out (fn (_, (_, body), _) => is_undefined body) cases
+ | NONE =>
+ (case cases' of
+ [] => cases
+ | (default, cs) :: _ =>
+ if length cs = 1 then cases
+ else if length cs = length constructors then
+ [hd cases, (dummy, ([], default), false)]
+ else
+ filter_out (fn (c, _, _) => member op aconv cs c) cases @
+ [(dummy, ([], default), false)])))
+ end
+ else NONE
+ | _ => NONE)
+ end
+ | _ => NONE);
+
+
+(* destruct nested patterns *)
+
+fun encode_clause S T (pat, rhs) =
+ fold (fn x as (_, U) => fn t =>
+ Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t)
+ (Term.add_frees pat [])
+ (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs);
+
+fun encode_cases S T [] = Const (@{const_name case_nil}, S --> T)
+ | encode_cases S T (p :: ps) =
+ Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $
+ encode_clause S T p $ encode_cases S T ps;
+
+fun encode_case (t, ps as (pat, rhs) :: _) =
+ encode_cases (fastype_of pat) (fastype_of rhs) ps $ t
+ | encode_case _ = case_error "encode_case";
+
+fun strip_case' ctxt d (pat, rhs) =
+ (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of
+ SOME (exp as Free _, clauses) =>
+ if Term.exists_subterm (curry (op aconv) exp) pat andalso
+ not (exists (fn (_, rhs') =>
+ Term.exists_subterm (curry (op aconv) exp) rhs') clauses)
+ then
+ maps (strip_case' ctxt d) (map (fn (pat', rhs') =>
+ (subst_free [(exp, pat')] pat, rhs')) clauses)
+ else [(pat, rhs)]
+ | _ => [(pat, rhs)]);
+
+fun strip_case ctxt d t =
+ (case dest_case ctxt d Name.context t of
+ SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses)
+ | NONE =>
+ (case t of
+ (t $ u) => strip_case ctxt d t $ strip_case ctxt d u
+ | (Abs (x, T, u)) =>
+ let val (x', u') = Term.dest_abs (x, T, u);
+ in Term.absfree (x', T) (strip_case ctxt d u') end
+ | _ => t));
+
+
+(* term uncheck *)
+
+val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);
+
+fun uncheck_case ctxt ts =
+ if Config.get ctxt show_cases then map (strip_case ctxt true) ts else ts;
+
+val term_uncheck_setup =
+ Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case);
+
+
+(* theory setup *)
+
+val setup =
+ trfun_setup #>
+ trfun_setup' #>
+ term_check_setup #>
+ term_uncheck_setup;
+
+
+(* outer syntax commands *)
+
+fun print_case_translations ctxt =
+ let
+ val cases = Symtab.dest (cases_of ctxt);
+ fun show_case (_, data as (comb, ctrs)) =
+ Pretty.big_list
+ (Pretty.string_of (Pretty.block [Pretty.str (Tname_of_data data), Pretty.str ":"]))
+ [Pretty.block [Pretty.brk 3, Pretty.block
+ [Pretty.str "combinator:", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt comb)]],
+ Pretty.block [Pretty.brk 3, Pretty.block
+ [Pretty.str "constructors:", Pretty.brk 1,
+ Pretty.list "" "" (map (Pretty.quote o Syntax.pretty_term ctxt) ctrs)]]];
+ in
+ Pretty.big_list "Case translations:" (map show_case cases)
+ |> Pretty.writeln
+ end;
+
+val _ =
+ Outer_Syntax.improper_command @{command_spec "print_case_translations"}
+ "print registered case combinators and constructors"
+ (Scan.succeed (Toplevel.keep (print_case_translations o Toplevel.context_of)))
+
+end;