# HG changeset patch # User wenzelm # Date 1365614087 -7200 # Node ID 43a3465805dd5c7e689d58d3faeb855df870d44f # Parent baefa3b461c2531f0c2a39c37853f06e743b28c2# Parent 27ecd33d3366943ef450bc9210d5fff97e062667 merged diff -r 27ecd33d3366 -r 43a3465805dd CONTRIBUTORS --- a/CONTRIBUTORS Wed Apr 10 17:27:38 2013 +0200 +++ b/CONTRIBUTORS Wed Apr 10 19:14:47 2013 +0200 @@ -6,6 +6,12 @@ Contributions to this Isabelle version -------------------------------------- +* April 2013: Stefan Berghofer, secunet Security Networks AG + Dmitriy Traytel, TUM + Makarius Wenzel, Université Paris-Sud / LRI + Case translations as a separate check phase independent of the + datatype package. + * March 2013: Florian Haftmann, TUM Reform of "big operators" on sets. diff -r 27ecd33d3366 -r 43a3465805dd NEWS --- a/NEWS Wed Apr 10 17:27:38 2013 +0200 +++ b/NEWS Wed Apr 10 19:14:47 2013 +0200 @@ -43,6 +43,13 @@ *** HOL *** +* Nested case expressions are now translated in a separate check + phase rather than during parsing. The data for case combinators + is separated from the datatype package. The declaration attribute + "case_translation" can be used to register new case combinators: + + declare [[case_translation case_combinator constructor1 ... constructorN]] + * Notation "{p:A. P}" now allows tuple patterns as well. * Revised devices for recursive definitions over finite sets: diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/HOL.thy --- a/src/HOL/HOL.thy Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/HOL.thy Wed Apr 10 19:14:47 2013 +0200 @@ -8,7 +8,8 @@ imports Pure "~~/src/Tools/Code_Generator" keywords "try" "solve_direct" "quickcheck" - "print_coercions" "print_coercion_maps" "print_claset" "print_induct_rules" :: diag and + "print_coercions" "print_coercion_maps" "print_claset" "print_induct_rules" + "print_case_translations":: diag and "quickcheck_params" :: thy_decl begin diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Inductive.thy --- a/src/HOL/Inductive.thy Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Inductive.thy Wed Apr 10 19:14:47 2013 +0200 @@ -273,7 +273,21 @@ ML_file "Tools/Datatype/datatype_aux.ML" ML_file "Tools/Datatype/datatype_prop.ML" ML_file "Tools/Datatype/datatype_data.ML" setup Datatype_Data.setup -ML_file "Tools/Datatype/datatype_case.ML" setup Datatype_Case.setup + +consts + case_guard :: "bool \ 'a \ ('a \ 'b) \ 'b" + case_nil :: "'a \ 'b" + case_cons :: "('a \ 'b) \ ('a \ 'b) \ 'a \ 'b" + case_elem :: "'a \ 'b \ 'a \ 'b" + case_abs :: "('c \ 'b) \ 'b" +declare [[coercion_args case_guard - + -]] +declare [[coercion_args case_cons - -]] +declare [[coercion_args case_abs -]] +declare [[coercion_args case_elem - +]] + +ML_file "Tools/case_translation.ML" +setup Case_Translation.setup + ML_file "Tools/Datatype/rep_datatype.ML" ML_file "Tools/Datatype/datatype_codegen.ML" setup Datatype_Codegen.setup ML_file "Tools/Datatype/primrec.ML" @@ -290,7 +304,7 @@ fun fun_tr ctxt [cs] = let val x = Syntax.free (fst (Name.variant "x" (Term.declare_term_frees cs Name.context))); - val ft = Datatype_Case.case_tr true ctxt [x, cs]; + val ft = Case_Translation.case_tr true ctxt [x, cs]; in lambda x ft end in [(@{syntax_const "_lam_pats_syntax"}, fun_tr)] end *} diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/List.thy --- a/src/HOL/List.thy Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/List.thy Wed Apr 10 19:14:47 2013 +0200 @@ -407,7 +407,7 @@ Syntax.const @{syntax_const "_case1"} $ Syntax.const @{const_syntax dummy_pattern} $ NilC; val cs = Syntax.const @{syntax_const "_case2"} $ case1 $ case2; - in Syntax_Trans.abs_tr [x, Datatype_Case.case_tr false ctxt [x, cs]] end; + in Syntax_Trans.abs_tr [x, Case_Translation.case_tr false ctxt [x, cs]] end; fun abs_tr ctxt p e opti = (case Term_Position.strip_positions p of diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Probability/Borel_Space.thy --- a/src/HOL/Probability/Borel_Space.thy Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Probability/Borel_Space.thy Wed Apr 10 19:14:47 2013 +0200 @@ -136,13 +136,8 @@ interpret countable_basis using assms by unfold_locales fix X::"'a set" assume "open X" from open_countable_basisE[OF this] guess B' . note B' = this - show "X \ sigma_sets UNIV B" - proof cases - assume "B' \ {}" - thus "X \ sigma_sets UNIV B" using assms B' - by (metis from_nat_into Union_image_eq countable_subset range_from_nat_into - in_mono sigma_sets.Basic sigma_sets.Union) - qed (simp add: sigma_sets.Empty B') + then show "X \ sigma_sets UNIV B" + by (blast intro: sigma_sets_UNION `countable B` countable_subset) next fix b assume "b \ B" hence "open b" by (rule topological_basis_open[OF assms(2)]) @@ -206,22 +201,6 @@ using measurable_comp[OF g borel_measurable_continuous_on_open'[OF cont], of c] by (simp add: comp_def) -lemma continuous_on_fst: "continuous_on UNIV fst" -proof - - have [simp]: "range fst = UNIV" by (auto simp: image_iff) - show ?thesis - using closed_vimage_fst - by (auto simp: continuous_on_closed closed_closedin vimage_def) -qed - -lemma continuous_on_snd: "continuous_on UNIV snd" -proof - - have [simp]: "range snd = UNIV" by (auto simp: image_iff) - show ?thesis - using closed_vimage_snd - by (auto simp: continuous_on_closed closed_closedin vimage_def) -qed - lemma borel_measurable_continuous_Pair: fixes f :: "'a \ 'b::second_countable_topology" and g :: "'a \ 'c::second_countable_topology" assumes [measurable]: "f \ borel_measurable M" @@ -242,11 +221,10 @@ assumes "g \ borel_measurable M" shows "(\x. f x \ g x) \ borel_measurable M" using assms - by (rule borel_measurable_continuous_Pair) - (intro continuous_on_inner continuous_on_snd continuous_on_fst) + by (rule borel_measurable_continuous_Pair) (intro continuous_on_intros) lemma [measurable]: - fixes a b :: "'a\ordered_euclidean_space" + fixes a b :: "'a\linorder_topology" shows lessThan_borel: "{..< a} \ sets borel" and greaterThan_borel: "{a <..} \ sets borel" and greaterThanLessThan_borel: "{a<.. sets borel" @@ -256,22 +234,58 @@ and greaterThanAtMost_borel: "{a<..b} \ sets borel" and atLeastLessThan_borel: "{a.. sets borel" unfolding greaterThanAtMost_def atLeastLessThan_def + by (blast intro: borel_open borel_closed open_lessThan open_greaterThan open_greaterThanLessThan + closed_atMost closed_atLeast closed_atLeastAtMost)+ + +lemma eucl_ivals[measurable]: + fixes a b :: "'a\ordered_euclidean_space" + shows "{..< a} \ sets borel" + and "{a <..} \ sets borel" + and "{a<.. sets borel" + and "{..a} \ sets borel" + and "{a..} \ sets borel" + and "{a..b} \ sets borel" + and "{a<..b} \ sets borel" + and "{a.. sets borel" + unfolding greaterThanAtMost_def atLeastLessThan_def by (blast intro: borel_open borel_closed)+ +lemma open_Collect_less: + fixes f g :: "'i::topological_space \ 'a :: {inner_dense_linorder, linorder_topology}" + assumes "continuous_on UNIV f" + assumes "continuous_on UNIV g" + shows "open {x. f x < g x}" +proof - + have "open (\y. {x \ UNIV. f x \ {..< y}} \ {x \ UNIV. g x \ {y <..}})" (is "open ?X") + by (intro open_UN ballI open_Int continuous_open_preimage assms) auto + also have "?X = {x. f x < g x}" + by (auto intro: dense) + finally show ?thesis . +qed + +lemma closed_Collect_le: + fixes f g :: "'i::topological_space \ 'a :: {inner_dense_linorder, linorder_topology}" + assumes f: "continuous_on UNIV f" + assumes g: "continuous_on UNIV g" + shows "closed {x. f x \ g x}" + using open_Collect_less[OF g f] unfolding not_less[symmetric] Collect_neg_eq open_closed . + lemma borel_measurable_less[measurable]: - fixes f :: "'a \ real" - assumes f: "f \ borel_measurable M" - assumes g: "g \ borel_measurable M" + fixes f :: "'a \ 'b::{second_countable_topology, inner_dense_linorder, linorder_topology}" + assumes "f \ borel_measurable M" + assumes "g \ borel_measurable M" shows "{w \ space M. f w < g w} \ sets M" proof - - have "{w \ space M. f w < g w} = {x \ space M. \r. f x < of_rat r \ of_rat r < g x}" - using Rats_dense_in_real by (auto simp add: Rats_def) - with f g show ?thesis - by simp + have "{w \ space M. f w < g w} = (\x. (f x, g x)) -` {x. fst x < snd x} \ space M" + by auto + also have "\ \ sets M" + by (intro measurable_sets[OF borel_measurable_Pair borel_open, OF assms open_Collect_less] + continuous_on_intros) + finally show ?thesis . qed lemma - fixes f :: "'a \ real" + fixes f :: "'a \ 'b::{second_countable_topology, inner_dense_linorder, linorder_topology}" assumes f[measurable]: "f \ borel_measurable M" assumes g[measurable]: "g \ borel_measurable M" shows borel_measurable_le[measurable]: "{w \ space M. f w \ g w} \ sets M" @@ -281,10 +295,11 @@ by measurable lemma - shows hafspace_less_borel: "{x::'a::euclidean_space. a < x \ i} \ sets borel" - and hafspace_greater_borel: "{x::'a::euclidean_space. x \ i < a} \ sets borel" - and hafspace_less_eq_borel: "{x::'a::euclidean_space. a \ x \ i} \ sets borel" - and hafspace_greater_eq_borel: "{x::'a::euclidean_space. x \ i \ a} \ sets borel" + fixes i :: "'a::{second_countable_topology, real_inner}" + shows hafspace_less_borel: "{x. a < x \ i} \ sets borel" + and hafspace_greater_borel: "{x. x \ i < a} \ sets borel" + and hafspace_less_eq_borel: "{x. a \ x \ i} \ sets borel" + and hafspace_greater_eq_borel: "{x. x \ i \ a} \ sets borel" by simp_all subsection "Borel space equals sigma algebras over intervals" @@ -636,22 +651,20 @@ subsection "Borel measurable operators" lemma borel_measurable_uminus[measurable (raw)]: - fixes g :: "'a \ real" + fixes g :: "'a \ 'b::{second_countable_topology, real_normed_vector}" assumes g: "g \ borel_measurable M" shows "(\x. - g x) \ borel_measurable M" - by (rule borel_measurable_continuous_on[OF _ g]) (auto intro: continuous_on_minus continuous_on_id) + by (rule borel_measurable_continuous_on[OF _ g]) (intro continuous_on_intros) lemma borel_measurable_add[measurable (raw)]: - fixes f g :: "'a \ 'c::ordered_euclidean_space" + fixes f g :: "'a \ 'b::{second_countable_topology, real_normed_vector}" assumes f: "f \ borel_measurable M" assumes g: "g \ borel_measurable M" shows "(\x. f x + g x) \ borel_measurable M" - using f g - by (rule borel_measurable_continuous_Pair) - (auto intro: continuous_on_fst continuous_on_snd continuous_on_add) + using f g by (rule borel_measurable_continuous_Pair) (intro continuous_on_intros) lemma borel_measurable_setsum[measurable (raw)]: - fixes f :: "'c \ 'a \ real" + fixes f :: "'c \ 'a \ 'b::{second_countable_topology, real_normed_vector}" assumes "\i. i \ S \ f i \ borel_measurable M" shows "(\x. \i\S. f i x) \ borel_measurable M" proof cases @@ -660,37 +673,41 @@ qed simp lemma borel_measurable_diff[measurable (raw)]: - fixes f :: "'a \ real" + fixes f :: "'a \ 'b::{second_countable_topology, real_normed_vector}" assumes f: "f \ borel_measurable M" assumes g: "g \ borel_measurable M" shows "(\x. f x - g x) \ borel_measurable M" unfolding diff_minus using assms by simp lemma borel_measurable_times[measurable (raw)]: - fixes f :: "'a \ real" + fixes f :: "'a \ 'b::{second_countable_topology, real_normed_algebra}" assumes f: "f \ borel_measurable M" assumes g: "g \ borel_measurable M" shows "(\x. f x * g x) \ borel_measurable M" - using f g - by (rule borel_measurable_continuous_Pair) - (auto intro: continuous_on_fst continuous_on_snd continuous_on_mult) + using f g by (rule borel_measurable_continuous_Pair) (intro continuous_on_intros) + +lemma borel_measurable_setprod[measurable (raw)]: + fixes f :: "'c \ 'a \ 'b::{second_countable_topology, real_normed_field}" + assumes "\i. i \ S \ f i \ borel_measurable M" + shows "(\x. \i\S. f i x) \ borel_measurable M" +proof cases + assume "finite S" + thus ?thesis using assms by induct auto +qed simp lemma borel_measurable_dist[measurable (raw)]: - fixes g f :: "'a \ 'b::ordered_euclidean_space" + fixes g f :: "'a \ 'b::{second_countable_topology, metric_space}" assumes f: "f \ borel_measurable M" assumes g: "g \ borel_measurable M" shows "(\x. dist (f x) (g x)) \ borel_measurable M" - using f g - by (rule borel_measurable_continuous_Pair) - (intro continuous_on_dist continuous_on_fst continuous_on_snd) + using f g by (rule borel_measurable_continuous_Pair) (intro continuous_on_intros) lemma borel_measurable_scaleR[measurable (raw)]: - fixes g :: "'a \ 'b::ordered_euclidean_space" + fixes g :: "'a \ 'b::{second_countable_topology, real_normed_vector}" assumes f: "f \ borel_measurable M" assumes g: "g \ borel_measurable M" shows "(\x. f x *\<^sub>R g x) \ borel_measurable M" - by (rule borel_measurable_continuous_Pair[OF f g]) - (auto intro!: continuous_on_scaleR continuous_on_fst continuous_on_snd) + using f g by (rule borel_measurable_continuous_Pair) (intro continuous_on_intros) lemma affine_borel_measurable_vector: fixes f :: "'a \ 'x::real_normed_vector" @@ -720,36 +737,29 @@ "f \ borel_measurable M \ (\x. a + f x ::'a::real_normed_vector) \ borel_measurable M" using affine_borel_measurable_vector[of f M a 1] by simp -lemma borel_measurable_setprod[measurable (raw)]: - fixes f :: "'c \ 'a \ real" - assumes "\i. i \ S \ f i \ borel_measurable M" - shows "(\x. \i\S. f i x) \ borel_measurable M" -proof cases - assume "finite S" - thus ?thesis using assms by induct auto -qed simp - lemma borel_measurable_inverse[measurable (raw)]: - fixes f :: "'a \ real" + fixes f :: "'a \ 'b::{second_countable_topology, real_normed_div_algebra}" assumes f: "f \ borel_measurable M" shows "(\x. inverse (f x)) \ borel_measurable M" proof - - have "(\x::real. if x \ UNIV - {0} then inverse x else 0) \ borel_measurable borel" - by (intro borel_measurable_continuous_on_open' continuous_on_inverse continuous_on_id) auto - also have "(\x::real. if x \ UNIV - {0} then inverse x else 0) = inverse" by (intro ext) auto + have "(\x::'b. if x \ UNIV - {0} then inverse x else inverse 0) \ borel_measurable borel" + by (intro borel_measurable_continuous_on_open' continuous_on_intros) auto + also have "(\x::'b. if x \ UNIV - {0} then inverse x else inverse 0) = inverse" + by (intro ext) auto finally show ?thesis using f by simp qed lemma borel_measurable_divide[measurable (raw)]: - "f \ borel_measurable M \ g \ borel_measurable M \ (\x. f x / g x::real) \ borel_measurable M" + "f \ borel_measurable M \ g \ borel_measurable M \ + (\x. f x / g x::'b::{second_countable_topology, real_normed_field}) \ borel_measurable M" by (simp add: field_divide_inverse) lemma borel_measurable_max[measurable (raw)]: - "f \ borel_measurable M \ g \ borel_measurable M \ (\x. max (g x) (f x) :: real) \ borel_measurable M" + "f \ borel_measurable M \ g \ borel_measurable M \ (\x. max (g x) (f x) :: 'b::{second_countable_topology, inner_dense_linorder, linorder_topology}) \ borel_measurable M" by (simp add: max_def) lemma borel_measurable_min[measurable (raw)]: - "f \ borel_measurable M \ g \ borel_measurable M \ (\x. min (g x) (f x) :: real) \ borel_measurable M" + "f \ borel_measurable M \ g \ borel_measurable M \ (\x. min (g x) (f x) :: 'b::{second_countable_topology, inner_dense_linorder, linorder_topology}) \ borel_measurable M" by (simp add: min_def) lemma borel_measurable_abs[measurable (raw)]: @@ -761,15 +771,15 @@ by (simp add: cart_eq_inner_axis) lemma convex_measurable: - fixes a b :: real - assumes X: "X \ borel_measurable M" "X ` space M \ { a <..< b}" - assumes q: "convex_on { a <..< b} q" + fixes A :: "'a :: ordered_euclidean_space set" + assumes X: "X \ borel_measurable M" "X ` space M \ A" "open A" + assumes q: "convex_on A q" shows "(\x. q (X x)) \ borel_measurable M" proof - - have "(\x. if X x \ {a <..< b} then q (X x) else 0) \ borel_measurable M" (is "?qX") + have "(\x. if X x \ A then q (X x) else 0) \ borel_measurable M" (is "?qX") proof (rule borel_measurable_continuous_on_open[OF _ _ X(1)]) - show "open {a<.. (\x. q (X x)) \ borel_measurable M" @@ -904,19 +914,6 @@ by (subst *) (simp del: space_borel split del: split_if) qed -lemma [measurable]: - fixes f g :: "'a \ ereal" - assumes f: "f \ borel_measurable M" - assumes g: "g \ borel_measurable M" - shows borel_measurable_ereal_le: "{x \ space M. f x \ g x} \ sets M" - and borel_measurable_ereal_less: "{x \ space M. f x < g x} \ sets M" - and borel_measurable_ereal_eq: "{w \ space M. f w = g w} \ sets M" - using f g by (simp_all add: set_Collect_ereal2) - -lemma borel_measurable_ereal_neq: - "f \ borel_measurable M \ g \ borel_measurable M \ {w \ space M. f w \ (g w :: ereal)} \ sets M" - by simp - lemma borel_measurable_ereal_iff: shows "(\x. ereal (f x)) \ borel_measurable M \ f \ borel_measurable M" proof @@ -1197,4 +1194,4 @@ shows "(\x. suminf (\i. f i x)) \ borel_measurable M" unfolding suminf_def sums_def[abs_def] lim_def[symmetric] by simp -end +end diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Probability/Probability_Measure.thy --- a/src/HOL/Probability/Probability_Measure.thy Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Probability/Probability_Measure.thy Wed Apr 10 19:14:47 2013 +0200 @@ -524,7 +524,7 @@ by (subst AE_distr_iff) (auto dest!: distributed_AE simp: measurable_split_conv split_beta - intro!: measurable_Pair borel_measurable_ereal_le) + intro!: measurable_Pair) show 2: "random_variable (distr (S \\<^isub>M T) (T \\<^isub>M S) (\(x, y). (y, x))) (\x. (Y x, X x))" using Pxy by auto { fix A assume A: "A \ sets (T \\<^isub>M S)" @@ -657,7 +657,7 @@ show Pxy: "(\(x, y). Px x * Py y) \ borel_measurable (S \\<^isub>M T)" by auto show "AE x in S \\<^isub>M T. 0 \ (case x of (x, y) \ Px x * Py y)" - apply (intro ST.AE_pair_measure borel_measurable_ereal_le Pxy borel_measurable_const) + apply (intro ST.AE_pair_measure borel_measurable_le Pxy borel_measurable_const) using distributed_AE[OF X] apply eventually_elim using distributed_AE[OF Y] diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Probability/Sigma_Algebra.thy --- a/src/HOL/Probability/Sigma_Algebra.thy Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Probability/Sigma_Algebra.thy Wed Apr 10 19:14:47 2013 +0200 @@ -449,8 +449,7 @@ text {*Sigma algebras can naturally be created as the closure of any set of M with regard to the properties just postulated. *} -inductive_set - sigma_sets :: "'a set \ 'a set set \ 'a set set" +inductive_set sigma_sets :: "'a set \ 'a set set \ 'a set set" for sp :: "'a set" and A :: "'a set set" where Basic[intro, simp]: "a \ A \ a \ sigma_sets sp A" @@ -535,6 +534,13 @@ finally show ?thesis . qed +lemma sigma_sets_UNION: "countable B \ (\b. b \ B \ b \ sigma_sets X A) \ (\B) \ sigma_sets X A" + using from_nat_into[of B] range_from_nat_into[of B] sigma_sets.Union[of "from_nat_into B" X A] + apply (cases "B = {}") + apply (simp add: sigma_sets.Empty) + apply (simp del: Union_image_eq add: Union_image_eq[symmetric]) + done + lemma (in sigma_algebra) sigma_sets_eq: "sigma_sets \ M = M" proof diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Tools/Datatype/datatype_case.ML --- a/src/HOL/Tools/Datatype/datatype_case.ML Wed Apr 10 17:27:38 2013 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,483 +0,0 @@ -(* Title: HOL/Tools/Datatype/datatype_case.ML - Author: Konrad Slind, Cambridge University Computer Laboratory - Author: Stefan Berghofer, TU Muenchen - -Datatype package: nested case expressions on datatypes. - -TODO: - * Avoid fragile operations on syntax trees (with type constraints - getting in the way). Instead work with auxiliary "destructor" - constants in translations and introduce the actual case - combinators in a separate term check phase (similar to term - abbreviations). - - * Avoid hard-wiring with datatype package. Instead provide generic - generic declarations of case splits based on an internal data slot. -*) - -signature DATATYPE_CASE = -sig - datatype config = Error | Warning | Quiet - type info = Datatype_Aux.info - val make_case : Proof.context -> config -> string list -> term -> (term * term) list -> term - val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option - val case_tr: bool -> Proof.context -> term list -> term - val show_cases: bool Config.T - val case_tr': string -> Proof.context -> term list -> term - val add_case_tr' : string list -> theory -> theory - val setup: theory -> theory -end; - -structure Datatype_Case : DATATYPE_CASE = -struct - -datatype config = Error | Warning | Quiet; -type info = Datatype_Aux.info; - -exception CASE_ERROR of string * int; - -fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty; - -(* Get information about datatypes *) - -fun ty_info ({descr, case_name, index, ...} : info) = - let - val (_, (tname, dts, constrs)) = nth descr index; - val mk_ty = Datatype_Aux.typ_of_dtyp descr; - val T = Type (tname, map mk_ty dts); - in - {case_name = case_name, - constructors = map (fn (cname, dts') => - Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs} - end; - - -(*Each pattern carries with it a tag i, which denotes the clause it -came from. i = ~1 indicates that the clause was added by pattern -completion.*) - -fun add_row_used ((prfx, pats), (tm, tag)) = - fold Term.add_free_names (tm :: pats @ map Free prfx); - -fun default_name name (t, cs) = - let - val name' = if name = "" then (case t of Free (name', _) => name' | _ => name) else name; - val cs' = if is_Free t then cs else filter_out Term_Position.is_position cs; - in (name', cs') end; - -fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) = - strip_constraints t ||> cons tT - | strip_constraints t = (t, []); - -fun constrain tT t = Syntax.const @{syntax_const "_constrain"} $ t $ tT; -fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT; - - -(*Produce an instance of a constructor, plus fresh variables for its arguments.*) -fun fresh_constr ty_match ty_inst colty used c = - let - val (_, T) = dest_Const c; - val Ts = binder_types T; - val names = - Name.variant_list used (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)); - val ty = body_type T; - val ty_theta = ty_match ty colty - handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1); - val c' = ty_inst ty_theta c; - val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts); - in (c', gvars) end; - -fun strip_comb_positions tm = - let - fun result t ts = (Term_Position.strip_positions t, ts); - fun strip (t as Const (@{syntax_const "_constrain"}, _) $ _ $ _) ts = result t ts - | strip (f $ t) ts = strip f (t :: ts) - | strip t ts = result t ts; - in strip tm [] end; - -(*Go through a list of rows and pick out the ones beginning with a - pattern with constructor = name.*) -fun mk_group (name, T) rows = - let val k = length (binder_types T) in - fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) => - fn ((in_group, not_in_group), (names, cnstrts)) => - (case strip_comb_positions p of - (Const (name', _), args) => - if name = name' then - if length args = k then - let - val constraints' = map strip_constraints args; - val (args', cnstrts') = split_list constraints'; - val (names', cnstrts'') = split_list (map2 default_name names constraints'); - in - ((((prfx, args' @ ps), rhs) :: in_group, not_in_group), - (names', map2 append cnstrts cnstrts'')) - end - else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i) - else ((in_group, row :: not_in_group), (names, cnstrts)) - | _ => raise CASE_ERROR ("Not a constructor pattern", i))) - rows (([], []), (replicate k "", replicate k [])) |>> pairself rev - end; - - -(* Partitioning *) - -fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1) - | partition ty_match ty_inst type_of used constructors colty res_ty - (rows as (((prfx, _ :: ps), _) :: _)) = - let - fun part [] [] = [] - | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i) - | part (c :: cs) rows = - let - val ((in_group, not_in_group), (names, cnstrts)) = mk_group (dest_Const c) rows; - val used' = fold add_row_used in_group used; - val (c', gvars) = fresh_constr ty_match ty_inst colty used' c; - val in_group' = - if null in_group (* Constructor not given *) - then - let - val Ts = map type_of ps; - val xs = - Name.variant_list - (fold Term.add_free_names gvars used') - (replicate (length ps) "x"); - in - [((prfx, gvars @ map Free (xs ~~ Ts)), - (Const (@{const_syntax undefined}, res_ty), ~1))] - end - else in_group; - in - {constructor = c', - new_formals = gvars, - names = names, - constraints = cnstrts, - group = in_group'} :: part cs not_in_group - end; - in part constructors rows end; - -fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats) - | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1); - - -(* Translation of pattern terms into nested case expressions. *) - -fun mk_case ctxt ty_match ty_inst type_of used range_ty = - let - val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt); - - fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1) - | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) = - if is_Free p then - let - val used' = add_row_used row used; - fun expnd c = - let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c) - in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end; - in map expnd constructors end - else [row]; - - val name = singleton (Name.variant_list used) "a"; - - fun mk _ [] = raise CASE_ERROR ("no rows", ~1) - | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *) - | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row] - | mk (u :: us) (rows as ((_, _ :: _), _) :: _) = - let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in - (case Option.map (apfst (fst o strip_comb_positions)) - (find_first (not o is_Free o fst) col0) of - NONE => - let - val rows' = map (fn ((v, _), row) => row ||> - apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows); - in mk us rows' end - | SOME (Const (cname, cT), i) => - (case Option.map ty_info (get_info (cname, cT)) of - NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i) - | SOME {case_name, constructors} => - let - val pty = body_type cT; - val used' = fold Term.add_free_names us used; - val nrows = maps (expand constructors used' pty) rows; - val subproblems = - partition ty_match ty_inst type_of used' - constructors pty range_ty nrows; - val (pat_rect, dtrees) = - split_list (map (fn {new_formals, group, ...} => - mk (new_formals @ us) group) subproblems); - val case_functions = - map2 (fn {new_formals, names, constraints, ...} => - fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t => - Abs (if s = "" then name else s, T, abstract_over (x, t)) - |> fold constrain_Abs cnstrts) (new_formals ~~ names ~~ constraints)) - subproblems dtrees; - val types = map type_of (case_functions @ [u]); - val case_const = Const (case_name, types ---> range_ty); - val tree = list_comb (case_const, case_functions @ [u]); - in (flat pat_rect, tree) end) - | SOME (t, i) => - raise CASE_ERROR ("Not a datatype constructor: " ^ Syntax.string_of_term ctxt t, i)) - end - | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1) - in mk end; - -fun case_error s = error ("Error in case expression:\n" ^ s); - -local - -(*Repeated variable occurrences in a pattern are not allowed.*) -fun no_repeat_vars ctxt pat = fold_aterms - (fn x as Free (s, _) => - (fn xs => - if member op aconv xs x then - case_error (quote s ^ " occurs repeatedly in the pattern " ^ - quote (Syntax.string_of_term ctxt pat)) - else x :: xs) - | _ => I) (Term_Position.strip_positions pat) []; - -fun gen_make_case ty_match ty_inst type_of ctxt config used x clauses = - let - fun string_of_clause (pat, rhs) = - Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs); - val _ = map (no_repeat_vars ctxt o fst) clauses; - val rows = map_index (fn (i, (pat, rhs)) => (([], [pat]), (rhs, i))) clauses; - val rangeT = - (case distinct (op =) (map (type_of o snd) clauses) of - [] => case_error "no clauses given" - | [T] => T - | _ => case_error "all cases must have the same result type"); - val used' = fold add_row_used rows used; - val (tags, case_tm) = - mk_case ctxt ty_match ty_inst type_of used' rangeT [x] rows - handle CASE_ERROR (msg, i) => - case_error - (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i))); - val _ = - (case subtract (op =) tags (map (snd o snd) rows) of - [] => () - | is => - (case config of Error => case_error | Warning => warning | Quiet => fn _ => ()) - ("The following clauses are redundant (covered by preceding clauses):\n" ^ - cat_lines (map (string_of_clause o nth clauses) is))); - in - case_tm - end; - -in - -fun make_case ctxt = - gen_make_case (match_type (Proof_Context.theory_of ctxt)) - Envir.subst_term_types fastype_of ctxt; - -val make_case_untyped = - gen_make_case (K (K Vartab.empty)) (K (Term.map_types (K dummyT))) (K dummyT); - -end; - - -(* parse translation *) - -fun case_tr err ctxt [t, u] = - let - val thy = Proof_Context.theory_of ctxt; - val intern_const_syntax = Consts.intern_syntax (Proof_Context.consts_of ctxt); - - (* replace occurrences of dummy_pattern by distinct variables *) - (* internalize constant names *) - (* FIXME proper name context!? *) - fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used = - let val (t', used') = prep_pat t used - in (c $ t' $ tT, used') end - | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used = - let val x = singleton (Name.variant_list used) "x" - in (Free (x, T), x :: used) end - | prep_pat (Const (s, T)) used = (Const (intern_const_syntax s, T), used) - | prep_pat (v as Free (s, T)) used = - let val s' = Proof_Context.intern_const ctxt s in - if Sign.declared_const thy s' then (Const (s', T), used) - else (v, used) - end - | prep_pat (t $ u) used = - let - val (t', used') = prep_pat t used; - val (u', used'') = prep_pat u used'; - in (t' $ u', used'') end - | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t); - - fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) = - let val (l', cnstrts) = strip_constraints l - in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end - | dest_case1 t = case_error "dest_case1"; - - fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u - | dest_case2 t = [t]; - - val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u)); - in - make_case_untyped ctxt - (if err then Error else Warning) [] - (fold constrain (filter_out Term_Position.is_position (flat cnstrts)) t) - cases - end - | case_tr _ _ _ = case_error "case_tr"; - -val trfun_setup = - Sign.add_advanced_trfuns ([], - [(@{syntax_const "_case_syntax"}, case_tr true)], - [], []); - - -(* Pretty printing of nested case expressions *) - -(* destruct one level of pattern matching *) - -local - -fun gen_dest_case name_of type_of ctxt d used t = - (case apfst name_of (strip_comb t) of - (SOME cname, ts as _ :: _) => - let - val (fs, x) = split_last ts; - fun strip_abs i Us t = - let - val zs = strip_abs_vars t; - val j = length zs; - val (xs, ys) = - if j < i then (zs @ map (pair "x") (drop j Us), []) - else chop i zs; - val u = fold_rev Term.abs ys (strip_abs_body t); - val xs' = map Free - ((fold_map Name.variant (map fst xs) - (Term.declare_term_names u used) |> fst) ~~ - map snd xs); - val (xs1, xs2) = chop j xs' - in (xs', list_comb (subst_bounds (rev xs1, u), xs2)) end; - fun is_dependent i t = - let val k = length (strip_abs_vars t) - i - in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end; - fun count_cases (_, _, true) = I - | count_cases (c, (_, body), false) = AList.map_default op aconv (body, []) (cons c); - val is_undefined = name_of #> equal (SOME @{const_name undefined}); - fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body); - val get_info = Datatype_Data.info_of_case (Proof_Context.theory_of ctxt); - in - (case Option.map ty_info (get_info cname) of - SOME {constructors, ...} => - if length fs = length constructors then - let - val cases = map (fn (Const (s, U), t) => - let - val Us = binder_types U; - val k = length Us; - val p as (xs, _) = strip_abs k Us t; - in - (Const (s, map type_of xs ---> type_of x), p, is_dependent k t) - end) (constructors ~~ fs); - val cases' = - sort (int_ord o swap o pairself (length o snd)) - (fold_rev count_cases cases []); - val R = type_of t; - val dummy = - if d then Term.dummy_pattern R - else Free (Name.variant "x" used |> fst, R); - in - SOME (x, - map mk_case - (case find_first (is_undefined o fst) cases' of - SOME (_, cs) => - if length cs = length constructors then [hd cases] - else filter_out (fn (_, (_, body), _) => is_undefined body) cases - | NONE => - (case cases' of - [] => cases - | (default, cs) :: _ => - if length cs = 1 then cases - else if length cs = length constructors then - [hd cases, (dummy, ([], default), false)] - else - filter_out (fn (c, _, _) => member op aconv cs c) cases @ - [(dummy, ([], default), false)]))) - end - else NONE - | _ => NONE) - end - | _ => NONE); - -in - -val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of; -val dest_case' = gen_dest_case (try (dest_Const #> fst #> Lexicon.unmark_const)) (K dummyT); - -end; - - -(* destruct nested patterns *) - -local - -fun strip_case'' dest (pat, rhs) = - (case dest (Term.declare_term_frees pat Name.context) rhs of - SOME (exp as Free _, clauses) => - if Term.exists_subterm (curry (op aconv) exp) pat andalso - not (exists (fn (_, rhs') => - Term.exists_subterm (curry (op aconv) exp) rhs') clauses) - then - maps (strip_case'' dest) (map (fn (pat', rhs') => - (subst_free [(exp, pat')] pat, rhs')) clauses) - else [(pat, rhs)] - | _ => [(pat, rhs)]); - -fun gen_strip_case dest t = - (case dest Name.context t of - SOME (x, clauses) => SOME (x, maps (strip_case'' dest) clauses) - | NONE => NONE); - -in - -val strip_case = gen_strip_case oo dest_case; -val strip_case' = gen_strip_case oo dest_case'; - -end; - - -(* print translation *) - -val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true); - -fun case_tr' cname ctxt ts = - if Config.get ctxt show_cases then - let - fun mk_clause (pat, rhs) = - let val xs = Term.add_frees pat [] in - Syntax.const @{syntax_const "_case1"} $ - map_aterms - (fn Free p => Syntax_Trans.mark_bound_abs p - | Const (s, _) => Syntax.const (Lexicon.mark_const s) - | t => t) pat $ - map_aterms - (fn x as Free v => - if member (op =) xs v then Syntax_Trans.mark_bound_body v else x - | t => t) rhs - end; - in - (case strip_case' ctxt true (list_comb (Syntax.const cname, ts)) of - SOME (x, clauses) => - Syntax.const @{syntax_const "_case_syntax"} $ x $ - foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u) - (map mk_clause clauses) - | NONE => raise Match) - end - else raise Match; - -fun add_case_tr' case_names thy = - Sign.add_advanced_trfuns ([], [], - map (fn case_name => - let val case_name' = Lexicon.mark_const case_name - in (case_name', case_tr' case_name') end) case_names, []) thy; - - -(* theory setup *) - -val setup = trfun_setup; - -end; diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Tools/Datatype/rep_datatype.ML --- a/src/HOL/Tools/Datatype/rep_datatype.ML Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Tools/Datatype/rep_datatype.ML Wed Apr 10 19:14:47 2013 +0200 @@ -518,6 +518,12 @@ val unnamed_rules = map (fn induct => ((Binding.empty, [Rule_Cases.inner_rule, Induct.induct_type ""]), [([induct], [])])) (drop (length dt_names) inducts); + + val ctxt = Proof_Context.init_global thy9; + val case_combs = map (Proof_Context.read_const ctxt false dummyT) case_names; + val constrss = map (fn (dtname, {descr, index, ...}) => + map (Proof_Context.read_const ctxt false dummyT o fst) + (#3 (the (AList.lookup op = descr index)))) dt_infos in thy9 |> Global_Theory.note_thmss "" @@ -535,8 +541,8 @@ named_rules @ unnamed_rules) |> snd |> Datatype_Data.register dt_infos + |> Context.theory_map (fold2 Case_Translation.register case_combs constrss) |> Datatype_Data.interpretation_data (config, dt_names) - |> Datatype_Case.add_case_tr' case_names |> pair dt_names end; diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Wed Apr 10 19:14:47 2013 +0200 @@ -640,7 +640,7 @@ val v = Free (name, T); val v' = Free (name', T); in - lambda v (Datatype_Case.make_case ctxt Datatype_Case.Quiet [] v + lambda v (Case_Translation.make_case ctxt Case_Translation.Quiet Name.context v [(HOLogic.mk_tuple out_ts, if null eqs'' then success_t else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $ @@ -916,7 +916,7 @@ in (pattern, compilation) end - val switch = Datatype_Case.make_case ctxt Datatype_Case.Quiet [] inp_var + val switch = Case_Translation.make_case ctxt Case_Translation.Quiet Name.context inp_var ((map compile_single_case switched_clauses) @ [(xt, mk_empty compfuns (HOLogic.mk_tupleT outTs))]) in diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Tools/Quickcheck/exhaustive_generators.ML --- a/src/HOL/Tools/Quickcheck/exhaustive_generators.ML Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Tools/Quickcheck/exhaustive_generators.ML Wed Apr 10 19:14:47 2013 +0200 @@ -292,8 +292,9 @@ val bound_vars' = union (op =) (vars_of lhs) (union (op =) varnames bound_vars) val cont_t = mk_smart_test_term' concl bound_vars' (new_assms @ assms) genuine in - mk_test (vars_of lhs, Datatype_Case.make_case ctxt Datatype_Case.Quiet [] lhs - [(list_comb (constr, vars), cont_t), (dummy_var, none_t)]) + mk_test (vars_of lhs, + Case_Translation.make_case ctxt Case_Translation.Quiet Name.context lhs + [(list_comb (constr, vars), cont_t), (dummy_var, none_t)]) end else c (assm, assms) fun default (assm, assms) = diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Tools/Quickcheck/narrowing_generators.ML --- a/src/HOL/Tools/Quickcheck/narrowing_generators.ML Wed Apr 10 17:27:38 2013 +0200 +++ b/src/HOL/Tools/Quickcheck/narrowing_generators.ML Wed Apr 10 19:14:47 2013 +0200 @@ -427,7 +427,7 @@ end fun mk_case_term ctxt p ((@{const_name Ex}, (x, T)) :: qs') (Existential_Counterexample cs) = - Datatype_Case.make_case ctxt Datatype_Case.Quiet [] (Free (x, T)) (map (fn (t, c) => + Case_Translation.make_case ctxt Case_Translation.Quiet Name.context (Free (x, T)) (map (fn (t, c) => (t, mk_case_term ctxt (p - 1) qs' c)) cs) | mk_case_term ctxt p ((@{const_name All}, _) :: qs') (Universal_Counterexample (t, c)) = if p = 0 then t else mk_case_term ctxt (p - 1) qs' c diff -r 27ecd33d3366 -r 43a3465805dd src/HOL/Tools/case_translation.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/case_translation.ML Wed Apr 10 19:14:47 2013 +0200 @@ -0,0 +1,602 @@ +(* Title: Tools/case_translation.ML + Author: Konrad Slind, Cambridge University Computer Laboratory + Author: Stefan Berghofer, TU Muenchen + Author: Dmitriy Traytel, TU Muenchen + +Nested case expressions via a generic data slot for case combinators and constructors. +*) + +signature CASE_TRANSLATION = +sig + datatype config = Error | Warning | Quiet + val case_tr: bool -> Proof.context -> term list -> term + val lookup_by_constr: Proof.context -> string * typ -> (term * term list) option + val lookup_by_constr_permissive: Proof.context -> string * typ -> (term * term list) option + val lookup_by_case: Proof.context -> string -> (term * term list) option + val make_case: Proof.context -> config -> Name.context -> term -> (term * term) list -> term + val print_case_translations: Proof.context -> unit + val strip_case: Proof.context -> bool -> term -> (term * (term * term) list) option + val strip_case_full: Proof.context -> bool -> term -> term + val show_cases: bool Config.T + val setup: theory -> theory + val register: term -> term list -> Context.generic -> Context.generic +end; + +structure Case_Translation: CASE_TRANSLATION = +struct + +(** data management **) + +datatype data = Data of + {constrs: (string * (term * term list)) list Symtab.table, + cases: (term * term list) Symtab.table}; + +fun make_data (constrs, cases) = Data {constrs = constrs, cases = cases}; + +structure Data = Generic_Data +( + type T = data; + val empty = make_data (Symtab.empty, Symtab.empty); + val extend = I; + fun merge + (Data {constrs = constrs1, cases = cases1}, + Data {constrs = constrs2, cases = cases2}) = + make_data + (Symtab.join (K (AList.merge (op =) (K true))) (constrs1, constrs2), + Symtab.merge (K true) (cases1, cases2)); +); + +fun map_data f = + Data.map (fn Data {constrs, cases} => make_data (f (constrs, cases))); +fun map_constrs f = map_data (fn (constrs, cases) => (f constrs, cases)); +fun map_cases f = map_data (fn (constrs, cases) => (constrs, f cases)); + +val rep_data = (fn Data args => args) o Data.get o Context.Proof; + +fun T_of_data (comb, constrs) = + fastype_of comb + |> funpow (length constrs) range_type + |> domain_type; + +val Tname_of_data = fst o dest_Type o T_of_data; + +val constrs_of = #constrs o rep_data; +val cases_of = #cases o rep_data; + +fun lookup_by_constr ctxt (c, T) = + let + val tab = Symtab.lookup_list (constrs_of ctxt) c; + in + (case body_type T of + Type (tyco, _) => AList.lookup (op =) tab tyco + | _ => NONE) + end; + +fun lookup_by_constr_permissive ctxt (c, T) = + let + val tab = Symtab.lookup_list (constrs_of ctxt) c; + val hint = (case body_type T of Type (tyco, _) => SOME tyco | _ => NONE); + val default = if null tab then NONE else SOME (snd (List.last tab)); + (*conservative wrt. overloaded constructors*) + in + (case hint of + NONE => default + | SOME tyco => + (case AList.lookup (op =) tab tyco of + NONE => default (*permissive*) + | SOME info => SOME info)) + end; + +val lookup_by_case = Symtab.lookup o cases_of; + + +(** installation **) + +fun case_error s = error ("Error in case expression:\n" ^ s); + +val name_of = try (dest_Const #> fst); + +(* parse translation *) + +fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT; + +fun case_tr err ctxt [t, u] = + let + val thy = Proof_Context.theory_of ctxt; + + fun is_const s = + Sign.declared_const thy (Proof_Context.intern_const ctxt s); + + fun abs p tTs t = Syntax.const @{const_syntax case_abs} $ + fold constrain_Abs tTs (absfree p t); + + fun abs_pat (Const ("_constrain", _) $ t $ tT) tTs = abs_pat t (tT :: tTs) + | abs_pat (Free (p as (x, _))) tTs = + if is_const x then I else abs p tTs + | abs_pat (t $ u) _ = abs_pat u [] #> abs_pat t [] + | abs_pat _ _ = I; + + (* replace occurrences of dummy_pattern by distinct variables *) + fun replace_dummies (Const (@{const_syntax dummy_pattern}, T)) used = + let val (x, used') = Name.variant "x" used + in (Free (x, T), used') end + | replace_dummies (t $ u) used = + let + val (t', used') = replace_dummies t used; + val (u', used'') = replace_dummies u used'; + in (t' $ u', used'') end + | replace_dummies t used = (t, used); + + fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) = + let val (l', _) = replace_dummies l (Term.declare_term_frees t Name.context) + in abs_pat l' [] + (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l' $ r) + end + | dest_case1 _ = case_error "dest_case1"; + + fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u + | dest_case2 t = [t]; + + val errt = if err then @{term True} else @{term False}; + in + Syntax.const @{const_syntax case_guard} $ errt $ t $ (fold_rev + (fn t => fn u => + Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u) + (dest_case2 u) + (Syntax.const @{const_syntax case_nil})) + end + | case_tr _ _ _ = case_error "case_tr"; + +val trfun_setup = + Sign.add_advanced_trfuns ([], + [(@{syntax_const "_case_syntax"}, case_tr true)], + [], []); + + +(* print translation *) + +fun case_tr' [_, x, t] = + let + fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used = + let val (s', used') = Name.variant s used + in mk_clause t ((s', T) :: xs) used' end + | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ = + Syntax.const @{syntax_const "_case1"} $ + subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $ + subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs); + + fun mk_clauses (Const (@{const_syntax case_nil}, _)) = [] + | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) = + mk_clause t [] (Term.declare_term_frees t Name.context) :: + mk_clauses u + in + Syntax.const @{syntax_const "_case_syntax"} $ x $ + foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u) + (mk_clauses t) + end; + +val trfun_setup' = Sign.add_trfuns + ([], [], [(@{const_syntax "case_guard"}, case_tr')], []); + + +(* declarations *) + +fun register raw_case_comb raw_constrs context = + let + val ctxt = Context.proof_of context; + val case_comb = singleton (Variable.polymorphic ctxt) raw_case_comb; + val constrs = Variable.polymorphic ctxt raw_constrs; + val case_key = case_comb |> dest_Const |> fst; + val constr_keys = map (fst o dest_Const) constrs; + val data = (case_comb, constrs); + val Tname = Tname_of_data data; + val update_constrs = fold (fn key => Symtab.cons_list (key, (Tname, data))) constr_keys; + val update_cases = Symtab.update (case_key, data); + in + context + |> map_constrs update_constrs + |> map_cases update_cases + end; + + +(* (Un)check phases *) + +datatype config = Error | Warning | Quiet; + +exception CASE_ERROR of string * int; + + +(*Each pattern carries with it a tag i, which denotes the clause it +came from. i = ~1 indicates that the clause was added by pattern +completion.*) + +fun add_row_used ((prfx, pats), (tm, tag)) = + fold Term.declare_term_frees (tm :: pats @ map Free prfx); + +(* try to preserve names given by user *) +fun default_name "" (Free (name', _)) = name' + | default_name name _ = name; + + +(*Produce an instance of a constructor, plus fresh variables for its arguments.*) +fun fresh_constr colty used c = + let + val (_, T) = dest_Const c; + val Ts = binder_types T; + val (names, _) = fold_map Name.variant + (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used; + val ty = body_type T; + val ty_theta = Type.raw_match (ty, colty) Vartab.empty + handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1); + val c' = Envir.subst_term_types ty_theta c; + val gvars = map (Envir.subst_term_types ty_theta o Free) (names ~~ Ts); + in (c', gvars) end; + +(*Go through a list of rows and pick out the ones beginning with a + pattern with constructor = name.*) +fun mk_group (name, T) rows = + let val k = length (binder_types T) in + fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) => + fn ((in_group, not_in_group), names) => + (case strip_comb p of + (Const (name', _), args) => + if name = name' then + if length args = k then + ((((prfx, args @ ps), rhs) :: in_group, not_in_group), + map2 default_name names args) + else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i) + else ((in_group, row :: not_in_group), names) + | _ => raise CASE_ERROR ("Not a constructor pattern", i))) + rows (([], []), replicate k "") |>> pairself rev + end; + + +(* Partitioning *) + +fun partition _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1) + | partition used constructors colty res_ty + (rows as (((prfx, _ :: ps), _) :: _)) = + let + fun part [] [] = [] + | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i) + | part (c :: cs) rows = + let + val ((in_group, not_in_group), names) = mk_group (dest_Const c) rows; + val used' = fold add_row_used in_group used; + val (c', gvars) = fresh_constr colty used' c; + val in_group' = + if null in_group (* Constructor not given *) + then + let + val Ts = map fastype_of ps; + val (xs, _) = + fold_map Name.variant + (replicate (length ps) "x") + (fold Term.declare_term_frees gvars used'); + in + [((prfx, gvars @ map Free (xs ~~ Ts)), + (Const (@{const_name undefined}, res_ty), ~1))] + end + else in_group; + in + {constructor = c', + new_formals = gvars, + names = names, + group = in_group'} :: part cs not_in_group + end; + in part constructors rows end; + +fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats) + | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1); + + +(* Translation of pattern terms into nested case expressions. *) + +fun mk_case ctxt used range_ty = + let + val get_info = lookup_by_constr_permissive ctxt; + + fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1) + | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) = + if is_Free p then + let + val used' = add_row_used row used; + fun expnd c = + let val capp = list_comb (fresh_constr ty used' c) + in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end; + in map expnd constructors end + else [row]; + + val (name, _) = Name.variant "a" used; + + fun mk _ [] = raise CASE_ERROR ("no rows", ~1) + | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *) + | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row] + | mk (u :: us) (rows as ((_, _ :: _), _) :: _) = + let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in + (case Option.map (apfst head_of) + (find_first (not o is_Free o fst) col0) of + NONE => + let + val rows' = map (fn ((v, _), row) => row ||> + apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows); + in mk us rows' end + | SOME (Const (cname, cT), i) => + (case get_info (cname, cT) of + NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i) + | SOME (case_comb, constructors) => + let + val pty = body_type cT; + val used' = fold Term.declare_term_frees us used; + val nrows = maps (expand constructors used' pty) rows; + val subproblems = + partition used' constructors pty range_ty nrows; + val (pat_rect, dtrees) = + split_list (map (fn {new_formals, group, ...} => + mk (new_formals @ us) group) subproblems); + val case_functions = + map2 (fn {new_formals, names, ...} => + fold_rev (fn (x as Free (_, T), s) => fn t => + Abs (if s = "" then name else s, T, abstract_over (x, t))) + (new_formals ~~ names)) + subproblems dtrees; + val types = map fastype_of (case_functions @ [u]); + val case_const = Const (name_of case_comb |> the, types ---> range_ty); + val tree = list_comb (case_const, case_functions @ [u]); + in (flat pat_rect, tree) end) + | SOME (t, i) => + raise CASE_ERROR ("Not a datatype constructor: " ^ Syntax.string_of_term ctxt t, i)) + end + | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1) + in mk end; + + +(*Repeated variable occurrences in a pattern are not allowed.*) +fun no_repeat_vars ctxt pat = fold_aterms + (fn x as Free (s, _) => + (fn xs => + if member op aconv xs x then + case_error (quote s ^ " occurs repeatedly in the pattern " ^ + quote (Syntax.string_of_term ctxt pat)) + else x :: xs) + | _ => I) pat []; + +fun make_case ctxt config used x clauses = + let + fun string_of_clause (pat, rhs) = + Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs); + val _ = map (no_repeat_vars ctxt o fst) clauses; + val rows = map_index (fn (i, (pat, rhs)) => (([], [pat]), (rhs, i))) clauses; + val rangeT = + (case distinct (op =) (map (fastype_of o snd) clauses) of + [] => case_error "no clauses given" + | [T] => T + | _ => case_error "all cases must have the same result type"); + val used' = fold add_row_used rows used; + val (tags, case_tm) = + mk_case ctxt used' rangeT [x] rows + handle CASE_ERROR (msg, i) => + case_error + (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i))); + val _ = + (case subtract (op =) tags (map (snd o snd) rows) of + [] => () + | is => + (case config of Error => case_error | Warning => warning | Quiet => fn _ => ()) + ("The following clauses are redundant (covered by preceding clauses):\n" ^ + cat_lines (map (string_of_clause o nth clauses) is))); + in + case_tm + end; + + +(* term check *) + +fun decode_clause (Const (@{const_name case_abs}, _) $ Abs (s, T, t)) xs used = + let val (s', used') = Name.variant s used + in decode_clause t (Free (s', T) :: xs) used' end + | decode_clause (Const (@{const_name case_elem}, _) $ t $ u) xs _ = + (subst_bounds (xs, t), subst_bounds (xs, u)) + | decode_clause _ _ _ = case_error "decode_clause"; + +fun decode_cases (Const (@{const_name case_nil}, _)) = [] + | decode_cases (Const (@{const_name case_cons}, _) $ t $ u) = + decode_clause t [] (Term.declare_term_frees t Name.context) :: + decode_cases u + | decode_cases _ = case_error "decode_cases"; + +fun check_case ctxt = + let + fun decode_case (Const (@{const_name case_guard}, _) $ b $ u $ t) = + make_case ctxt (if b = @{term True} then Error else Warning) + Name.context (decode_case u) (decode_cases t) + | decode_case (t $ u) = decode_case t $ decode_case u + | decode_case (Abs (x, T, u)) = + let val (x', u') = Term.dest_abs (x, T, u); + in Term.absfree (x', T) (decode_case u') end + | decode_case t = t; + in + map decode_case + end; + +val term_check_setup = + Context.theory_map (Syntax_Phases.term_check 1 "case" check_case); + + +(* Pretty printing of nested case expressions *) + +(* destruct one level of pattern matching *) + +fun dest_case ctxt d used t = + (case apfst name_of (strip_comb t) of + (SOME cname, ts as _ :: _) => + let + val (fs, x) = split_last ts; + fun strip_abs i Us t = + let + val zs = strip_abs_vars t; + val j = length zs; + val (xs, ys) = + if j < i then (zs @ map (pair "x") (drop j Us), []) + else chop i zs; + val u = fold_rev Term.abs ys (strip_abs_body t); + val xs' = map Free + ((fold_map Name.variant (map fst xs) + (Term.declare_term_names u used) |> fst) ~~ + map snd xs); + val (xs1, xs2) = chop j xs' + in (xs', list_comb (subst_bounds (rev xs1, u), xs2)) end; + fun is_dependent i t = + let val k = length (strip_abs_vars t) - i + in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end; + fun count_cases (_, _, true) = I + | count_cases (c, (_, body), false) = AList.map_default op aconv (body, []) (cons c); + val is_undefined = name_of #> equal (SOME @{const_name undefined}); + fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body); + val get_info = lookup_by_case ctxt; + in + (case get_info cname of + SOME (_, constructors) => + if length fs = length constructors then + let + val cases = map (fn (Const (s, U), t) => + let + val Us = binder_types U; + val k = length Us; + val p as (xs, _) = strip_abs k Us t; + in + (Const (s, map fastype_of xs ---> fastype_of x), p, is_dependent k t) + end) (constructors ~~ fs); + val cases' = + sort (int_ord o swap o pairself (length o snd)) + (fold_rev count_cases cases []); + val R = fastype_of t; + val dummy = + if d then Term.dummy_pattern R + else Free (Name.variant "x" used |> fst, R); + in + SOME (x, + map mk_case + (case find_first (is_undefined o fst) cases' of + SOME (_, cs) => + if length cs = length constructors then [hd cases] + else filter_out (fn (_, (_, body), _) => is_undefined body) cases + | NONE => + (case cases' of + [] => cases + | (default, cs) :: _ => + if length cs = 1 then cases + else if length cs = length constructors then + [hd cases, (dummy, ([], default), false)] + else + filter_out (fn (c, _, _) => member op aconv cs c) cases @ + [(dummy, ([], default), false)]))) + end + else NONE + | _ => NONE) + end + | _ => NONE); + + +(* destruct nested patterns *) + +fun encode_clause recur S T (pat, rhs) = + fold (fn x as (_, U) => fn t => + Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree x t) + (Term.add_frees pat []) + (Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ recur rhs); + +fun encode_cases _ S T [] = Const (@{const_name case_nil}, S --> T) + | encode_cases recur S T (p :: ps) = + Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $ + encode_clause recur S T p $ encode_cases recur S T ps; + +fun encode_case recur (t, ps as (pat, rhs) :: _) = + let + val tT = fastype_of t; + val T = fastype_of rhs; + in + Const (@{const_name case_guard}, @{typ bool} --> tT --> (tT --> T) --> T) $ + @{term True} $ t $ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps) + end + | encode_case _ _ = case_error "encode_case"; + +fun strip_case' ctxt d (pat, rhs) = + (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of + SOME (exp as Free _, clauses) => + if Term.exists_subterm (curry (op aconv) exp) pat andalso + not (exists (fn (_, rhs') => + Term.exists_subterm (curry (op aconv) exp) rhs') clauses) + then + maps (strip_case' ctxt d) (map (fn (pat', rhs') => + (subst_free [(exp, pat')] pat, rhs')) clauses) + else [(pat, rhs)] + | _ => [(pat, rhs)]); + +fun strip_case ctxt d t = + (case dest_case ctxt d Name.context t of + SOME (x, clauses) => SOME (x, maps (strip_case' ctxt d) clauses) + | NONE => NONE); + +fun strip_case_full ctxt d t = + (case dest_case ctxt d Name.context t of + SOME (x, clauses) => + encode_case (strip_case_full ctxt d) + (strip_case_full ctxt d x, maps (strip_case' ctxt d) clauses) + | NONE => + (case t of + (t $ u) => strip_case_full ctxt d t $ strip_case_full ctxt d u + | (Abs (x, T, u)) => + let val (x', u') = Term.dest_abs (x, T, u); + in Term.absfree (x', T) (strip_case_full ctxt d u') end + | _ => t)); + + +(* term uncheck *) + +val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true); + +fun uncheck_case ctxt ts = + if Config.get ctxt show_cases then map (strip_case_full ctxt true) ts else ts; + +val term_uncheck_setup = + Context.theory_map (Syntax_Phases.term_uncheck 1 "case" uncheck_case); + + +(* theory setup *) + +val setup = + trfun_setup #> + trfun_setup' #> + term_check_setup #> + term_uncheck_setup #> + Attrib.setup @{binding case_translation} + (Args.term -- Scan.repeat1 Args.term >> + (fn (t, ts) => Thm.declaration_attribute (K (register t ts)))) + "declaration of case combinators and constructors"; + + +(* outer syntax commands *) + +fun print_case_translations ctxt = + let + val cases = Symtab.dest (cases_of ctxt); + fun show_case (_, data as (comb, ctrs)) = + Pretty.big_list + (Pretty.string_of (Pretty.block [Pretty.str (Tname_of_data data), Pretty.str ":"])) + [Pretty.block [Pretty.brk 3, Pretty.block + [Pretty.str "combinator:", Pretty.brk 1, Pretty.quote (Syntax.pretty_term ctxt comb)]], + Pretty.block [Pretty.brk 3, Pretty.block + [Pretty.str "constructors:", Pretty.brk 1, + Pretty.list "" "" (map (Pretty.quote o Syntax.pretty_term ctxt) ctrs)]]]; + in + Pretty.big_list "Case translations:" (map show_case cases) + |> Pretty.writeln + end; + +val _ = + Outer_Syntax.improper_command @{command_spec "print_case_translations"} + "print registered case combinators and constructors" + (Scan.succeed (Toplevel.keep (print_case_translations o Toplevel.context_of))) + +end;