src/HOLCF/ex/Pattern_Match.thy
changeset 37109 e67760c1b851
child 39557 fe5722fce758
equal deleted inserted replaced
37108:00f13d3ad474 37109:e67760c1b851
       
     1 (*  Title:      HOLCF/ex/Pattern_Match.thy
       
     2     Author:     Brian Huffman
       
     3 *)
       
     4 
       
     5 header {* An experimental pattern-matching notation *}
       
     6 
       
     7 theory Pattern_Match
       
     8 imports HOLCF
       
     9 begin
       
    10 
       
text {* FIXME: Find a proper way to un-hide constants. *}

(* Short names for the match-monad constants of theory Fixrec.  As the
   FIXME above notes, they are otherwise only reachable under their
   qualified names Fixrec.fail, Fixrec.succeed and Fixrec.run; these
   abbreviations re-export them for use in this theory. *)

abbreviation fail :: "'a match"
where "fail \<equiv> Fixrec.fail"

abbreviation succeed :: "'a \<rightarrow> 'a match"
where "succeed \<equiv> Fixrec.succeed"

abbreviation run :: "'a match \<rightarrow> 'a"
where "run \<equiv> Fixrec.run"
       
    21 
       
subsection {* Fatbar combinator *}

(* fatbar a b is the match function that applies a to its argument and
   combines the result with b applied to the same argument using the
   match-monad operator +++.  The lemmas below spell out the effect:
   bottom from a propagates, fail from a falls through to b, and a
   success of a is returned as is. *)
definition
  fatbar :: "('a \<rightarrow> 'b match) \<rightarrow> ('a \<rightarrow> 'b match) \<rightarrow> ('a \<rightarrow> 'b match)" where
  "fatbar = (\<Lambda> a b x. a\<cdot>x +++ b\<cdot>x)"

(* Right-associative infix notation (priority 60): m1 followed by the
   parallel-bar symbol and m2 abbreviates fatbar applied to m1 and m2. *)
abbreviation
  fatbar_syn :: "['a \<rightarrow> 'b match, 'a \<rightarrow> 'b match] \<Rightarrow> 'a \<rightarrow> 'b match" (infixr "\<parallel>" 60)  where
  "m1 \<parallel> m2 == fatbar\<cdot>m1\<cdot>m2"
       
    31 
       
(* Case analysis of fatbar on the result of the first matcher m:
   bottom stays bottom, fail defers to ms, and a success is kept. *)

lemma fatbar1: "m\<cdot>x = \<bottom> \<Longrightarrow> (m \<parallel> ms)\<cdot>x = \<bottom>"
by (simp add: fatbar_def)

lemma fatbar2: "m\<cdot>x = fail \<Longrightarrow> (m \<parallel> ms)\<cdot>x = ms\<cdot>x"
by (simp add: fatbar_def)

lemma fatbar3: "m\<cdot>x = succeed\<cdot>y \<Longrightarrow> (m \<parallel> ms)\<cdot>x = succeed\<cdot>y"
by (simp add: fatbar_def)

(* Collected for explicit use; not declared [simp]. *)
lemmas fatbar_simps = fatbar1 fatbar2 fatbar3

(* The same three cases seen through run: bottom gives bottom, fail gives
   run of ms at x, and succeed y gives y. *)

lemma run_fatbar1: "m\<cdot>x = \<bottom> \<Longrightarrow> run\<cdot>((m \<parallel> ms)\<cdot>x) = \<bottom>"
by (simp add: fatbar_def)

lemma run_fatbar2: "m\<cdot>x = fail \<Longrightarrow> run\<cdot>((m \<parallel> ms)\<cdot>x) = run\<cdot>(ms\<cdot>x)"
by (simp add: fatbar_def)

lemma run_fatbar3: "m\<cdot>x = succeed\<cdot>y \<Longrightarrow> run\<cdot>((m \<parallel> ms)\<cdot>x) = y"
by (simp add: fatbar_def)

(* These run-level rules are declared [simp]. *)
lemmas run_fatbar_simps [simp] = run_fatbar1 run_fatbar2 run_fatbar3
       
    53 
       
subsection {* Case branch combinator *}

(* branch p r x runs the pattern p on x.  If p succeeds with y, the
   branch succeeds with r applied to y; if p fails, the branch fails; if
   p yields bottom, so does the branch (see branch_simps). *)
definition
  branch :: "('a \<rightarrow> 'b match) \<Rightarrow> ('b \<rightarrow> 'c) \<rightarrow> ('a \<rightarrow> 'c match)" where
  "branch p \<equiv> \<Lambda> r x. match_case\<cdot>fail\<cdot>(\<Lambda> y. succeed\<cdot>(r\<cdot>y))\<cdot>(p\<cdot>x)"

lemma branch_simps:
  "p\<cdot>x = \<bottom> \<Longrightarrow> branch p\<cdot>r\<cdot>x = \<bottom>"
  "p\<cdot>x = fail \<Longrightarrow> branch p\<cdot>r\<cdot>x = fail"
  "p\<cdot>x = succeed\<cdot>y \<Longrightarrow> branch p\<cdot>r\<cdot>x = succeed\<cdot>(r\<cdot>y)"
by (simp_all add: branch_def)

(* With the always-succeeding pattern succeed, branch simply applies r. *)
lemma branch_succeed [simp]: "branch succeed\<cdot>r\<cdot>x = succeed\<cdot>(r\<cdot>x)"
by (simp add: branch_def)
       
    68 
       
subsection {* Cases operator *}

(* cases extracts the value from a match result: succeed x gives x, while
   both fail and bottom give bottom (see the rewrite rules below). *)
definition
  cases :: "'a match \<rightarrow> 'a::pcpo" where
  "cases = match_case\<cdot>\<bottom>\<cdot>ID"

text {* rewrite rules for cases *}

lemma cases_strict [simp]: "cases\<cdot>\<bottom> = \<bottom>"
by (simp add: cases_def)

lemma cases_fail [simp]: "cases\<cdot>fail = \<bottom>"
by (simp add: cases_def)

lemma cases_succeed [simp]: "cases\<cdot>(succeed\<cdot>x) = x"
by (simp add: cases_def)
       
    85 
       
subsection {* Case syntax *}

(* Concrete syntax  Case x of p1 => r1 | ... | pn => rn.
   Each alternative is a Case_syn (built by _Case1); alternatives are
   joined with a vertical bar into a Cases_syn (via _Case2). *)

nonterminals
  Case_syn  Cases_syn

syntax
  "_Case_syntax":: "['a, Cases_syn] => 'b"               ("(Case _ of/ _)" 10)
  "_Case1"      :: "['a, 'b] => Case_syn"                ("(2_ =>/ _)" 10)
  ""            :: "Case_syn => Cases_syn"               ("_")
  "_Case2"      :: "[Case_syn, Cases_syn] => Cases_syn"  ("_/ | _")

(* With the xsymbols print mode, alternatives use the double-arrow symbol. *)
syntax (xsymbols)
  "_Case1"      :: "['a, 'b] => Case_syn"                ("(2_ \<Rightarrow>/ _)" 10)

(* The whole expression becomes cases applied to the combined matcher at the
   scrutinee, and consecutive alternatives are combined with fatbar. *)
translations
  "_Case_syntax x ms" == "CONST cases\<cdot>(ms\<cdot>x)"
  "_Case2 m ms" == "m \<parallel> ms"
       
   103 
       
text {* Parsing Case expressions *}

(* Auxiliary syntax constants used during parsing:
   _pat p         -- translates the source pattern p into a match function;
   _variable p r  -- abstracts the right-hand side r over the variables of p;
   _noargs        -- stands for a pattern that binds no variables. *)
syntax
  "_pat" :: "'a"
  "_variable" :: "'a"
  "_noargs" :: "'a"

(* An alternative p => r becomes branch applied to the match function for p
   and to r abstracted over the variables of p.  A pair of variable groups
   is consumed with csplit, and an empty group with unit_when. *)
translations
  "_Case1 p r" => "CONST branch (_pat p)\<cdot>(_variable p r)"
  "_variable (_args x y) r" => "CONST csplit\<cdot>(_variable x (_variable y r))"
  "_variable _noargs r" => "CONST unit_when\<cdot>r"

(* Term-level fallback for what the rules above and below leave behind,
   presumably plain variables: any remaining _pat becomes succeed, and a
   remaining _variable becomes a continuous lambda over the variable. *)
parse_translation {*
(* rewrite (_pat x) => (succeed) *)
(* rewrite (_variable x t) => (Abs_CFun (%x. t)) *)
 [(@{syntax_const "_pat"}, fn _ => Syntax.const @{const_syntax Fixrec.succeed}),
  mk_binder_tr (@{syntax_const "_variable"}, @{const_syntax Abs_CFun})];
*}
       
   122 
       
text {* Printing Case expressions *}

(* _match p v is a print-only constant pairing a pattern combinator p with
   the variable syntax v it binds; the translations below turn it back into
   source pattern syntax. *)
syntax
  "_match" :: "'a"

print_translation {*
  let
    (* dest_LAM undoes the parse-side encoding of a right-hand side:
       unit_when r gives (_noargs, r); csplit t gives (_args v1 v2, body)
       by decomposing t twice; a continuous lambda (eta-expanded first if
       it is not a syntactic abstraction) gives (_variable x, body).
       Any other shape raises Match. *)
    fun dest_LAM (Const (@{const_syntax Rep_CFun},_) $ Const (@{const_syntax unit_when},_) $ t) =
          (Syntax.const @{syntax_const "_noargs"}, t)
    |   dest_LAM (Const (@{const_syntax Rep_CFun},_) $ Const (@{const_syntax csplit},_) $ t) =
          let
            val (v1, t1) = dest_LAM t;
            val (v2, t2) = dest_LAM t1;
          in (Syntax.const @{syntax_const "_args"} $ v1 $ v2, t2) end
    |   dest_LAM (Const (@{const_syntax Abs_CFun},_) $ t) =
          let
            val abs =
              case t of Abs abs => abs
                | _ => ("x", dummyT, incr_boundvars 1 t $ Bound 0);
            val (x, t') = atomic_abs_tr' abs;
          in (Syntax.const @{syntax_const "_variable"} $ x, t') end
    |   dest_LAM _ = raise Match; (* too few vars: abort translation *)

    (* Registered on Rep_CFun: a term Rep_CFun (branch p) r is printed as the
       alternative _Case1 (_match p v) t, where (v, t) = dest_LAM r.  Other
       argument shapes raise Match, aborting this translation. *)
    fun Case1_tr' [Const(@{const_syntax branch},_) $ p, r] =
          let val (v, t) = dest_LAM r in
            Syntax.const @{syntax_const "_Case1"} $
              (Syntax.const @{syntax_const "_match"} $ p $ v) $ t
          end;

  in [(@{const_syntax Rep_CFun}, Case1_tr')] end;
*}

(* A variable pattern (succeed binding a single variable) prints as the
   variable itself. *)
translations
  "x" <= "_match (CONST succeed) (_variable x)"
       
   157 
       
   158 
       
subsection {* Pattern combinators for data constructors *}

(* A pattern of type ('a, 'b) pat inspects a value of type 'a and, when it
   matches, succeeds with the bound values of type 'b. *)
types ('a, 'b) pat = "'a \<rightarrow> 'b match"

(* Pair pattern: match the first component with p1; if it succeeds with a,
   match the second with p2; if that succeeds with b, succeed with (a, b).
   A failure of either sub-pattern gives fail. *)
definition
  cpair_pat :: "('a, 'c) pat \<Rightarrow> ('b, 'd) pat \<Rightarrow> ('a \<times> 'b, 'c \<times> 'd) pat" where
  "cpair_pat p1 p2 = (\<Lambda>(x, y).
    match_case\<cdot>fail\<cdot>(\<Lambda> a. match_case\<cdot>fail\<cdot>(\<Lambda> b. succeed\<cdot>(a, b))\<cdot>(p2\<cdot>y))\<cdot>(p1\<cdot>x))"

(* Strict-pair pattern: unpack the strict pair and delegate to cpair_pat. *)
definition
  spair_pat ::
  "('a, 'c) pat \<Rightarrow> ('b, 'd) pat \<Rightarrow> ('a::pcpo \<otimes> 'b::pcpo, 'c \<times> 'd) pat" where
  "spair_pat p1 p2 = (\<Lambda>(:x, y:). cpair_pat p1 p2\<cdot>(x, y))"

(* Left-injection pattern: apply p under sinl, fail on sinr. *)
definition
  sinl_pat :: "('a, 'c) pat \<Rightarrow> ('a::pcpo \<oplus> 'b::pcpo, 'c) pat" where
  "sinl_pat p = sscase\<cdot>p\<cdot>(\<Lambda> x. fail)"

(* Right-injection pattern: fail on sinl, apply p under sinr. *)
definition
  sinr_pat :: "('b, 'c) pat \<Rightarrow> ('a::pcpo \<oplus> 'b::pcpo, 'c) pat" where
  "sinr_pat p = sscase\<cdot>(\<Lambda> x. fail)\<cdot>p"

(* Lifting pattern: apply p to the value under up. *)
definition
  up_pat :: "('a, 'b) pat \<Rightarrow> ('a u, 'b) pat" where
  "up_pat p = fup\<cdot>p"

(* Constant patterns for TT, FF and ONE: succeed with () on the matching
   constant; TT_pat and FF_pat fail on the other truth value. *)
definition
  TT_pat :: "(tr, unit) pat" where
  "TT_pat = (\<Lambda> b. If b then succeed\<cdot>() else fail fi)"

definition
  FF_pat :: "(tr, unit) pat" where
  "FF_pat = (\<Lambda> b. If b then fail else succeed\<cdot>() fi)"

definition
  ONE_pat :: "(one, unit) pat" where
  "ONE_pat = (\<Lambda> ONE. succeed\<cdot>())"
       
   196 
       
text {* Parse translations (patterns) *}
(* A constructor pattern becomes the matching pattern combinator applied to
   the translated sub-patterns; constant constructors map directly to their
   combinator. *)
translations
  "_pat (XCONST Pair x y)" => "CONST cpair_pat (_pat x) (_pat y)"
  "_pat (XCONST spair\<cdot>x\<cdot>y)" => "CONST spair_pat (_pat x) (_pat y)"
  "_pat (XCONST sinl\<cdot>x)" => "CONST sinl_pat (_pat x)"
  "_pat (XCONST sinr\<cdot>x)" => "CONST sinr_pat (_pat x)"
  "_pat (XCONST up\<cdot>x)" => "CONST up_pat (_pat x)"
  "_pat (XCONST TT)" => "CONST TT_pat"
  "_pat (XCONST FF)" => "CONST FF_pat"
  "_pat (XCONST ONE)" => "CONST ONE_pat"

text {* CONST version is also needed for constructors with special syntax *}
translations
  "_pat (CONST Pair x y)" => "CONST cpair_pat (_pat x) (_pat y)"
  "_pat (CONST spair\<cdot>x\<cdot>y)" => "CONST spair_pat (_pat x) (_pat y)"

text {* Parse translations (variables) *}
(* Collect the variables bound by a constructor pattern: two-argument
   constructors give an _args pair, one-argument constructors pass their
   argument through, and constants bind nothing (_noargs). *)
translations
  "_variable (XCONST Pair x y) r" => "_variable (_args x y) r"
  "_variable (XCONST spair\<cdot>x\<cdot>y) r" => "_variable (_args x y) r"
  "_variable (XCONST sinl\<cdot>x) r" => "_variable x r"
  "_variable (XCONST sinr\<cdot>x) r" => "_variable x r"
  "_variable (XCONST up\<cdot>x) r" => "_variable x r"
  "_variable (XCONST TT) r" => "_variable _noargs r"
  "_variable (XCONST FF) r" => "_variable _noargs r"
  "_variable (XCONST ONE) r" => "_variable _noargs r"

(* CONST versions of the variable rules, for constructors with special syntax. *)
translations
  "_variable (CONST Pair x y) r" => "_variable (_args x y) r"
  "_variable (CONST spair\<cdot>x\<cdot>y) r" => "_variable (_args x y) r"

text {* Print translations *}
(* Inverse of the rules above: a _match of a pattern combinator and its
   variables is printed as the original constructor pattern. *)
translations
  "CONST Pair (_match p1 v1) (_match p2 v2)"
      <= "_match (CONST cpair_pat p1 p2) (_args v1 v2)"
  "CONST spair\<cdot>(_match p1 v1)\<cdot>(_match p2 v2)"
      <= "_match (CONST spair_pat p1 p2) (_args v1 v2)"
  "CONST sinl\<cdot>(_match p1 v1)" <= "_match (CONST sinl_pat p1) v1"
  "CONST sinr\<cdot>(_match p1 v1)" <= "_match (CONST sinr_pat p1) v1"
  "CONST up\<cdot>(_match p1 v1)" <= "_match (CONST up_pat p1) v1"
  "CONST TT" <= "_match (CONST TT_pat) _noargs"
  "CONST FF" <= "_match (CONST FF_pat) _noargs"
  "CONST ONE" <= "_match (CONST ONE_pat) _noargs"
       
   240 
       
(* Branching on a pair pattern is decided by the first component: if its
   branch diverges or fails, so does the pair's; if it succeeds with s, the
   result is the branch of the second pattern with s as right-hand side. *)
lemma cpair_pat1:
  "branch p\<cdot>r\<cdot>x = \<bottom> \<Longrightarrow> branch (cpair_pat p q)\<cdot>(csplit\<cdot>r)\<cdot>(x, y) = \<bottom>"
apply (simp add: branch_def cpair_pat_def)
apply (cases "p\<cdot>x", simp_all)
done

lemma cpair_pat2:
  "branch p\<cdot>r\<cdot>x = fail \<Longrightarrow> branch (cpair_pat p q)\<cdot>(csplit\<cdot>r)\<cdot>(x, y) = fail"
apply (simp add: branch_def cpair_pat_def)
apply (cases "p\<cdot>x", simp_all)
done

lemma cpair_pat3:
  "branch p\<cdot>r\<cdot>x = succeed\<cdot>s \<Longrightarrow>
   branch (cpair_pat p q)\<cdot>(csplit\<cdot>r)\<cdot>(x, y) = branch q\<cdot>s\<cdot>y"
apply (simp add: branch_def cpair_pat_def)
apply (cases "p\<cdot>x", simp_all)
apply (cases "q\<cdot>y", simp_all)
done

lemmas cpair_pat [simp] =
  cpair_pat1 cpair_pat2 cpair_pat3
       
   263 
       
(* Reduction rules for branch on each constructor pattern.  Each pattern is
   strict (bottom gives bottom).  For strict constructors, the rules for
   defined arguments carry non-bottom side conditions. *)

(* Strict pairs reduce to the corresponding cpair_pat branch. *)
lemma spair_pat [simp]:
  "branch (spair_pat p1 p2)\<cdot>r\<cdot>\<bottom> = \<bottom>"
  "\<lbrakk>x \<noteq> \<bottom>; y \<noteq> \<bottom>\<rbrakk>
     \<Longrightarrow> branch (spair_pat p1 p2)\<cdot>r\<cdot>(:x, y:) =
         branch (cpair_pat p1 p2)\<cdot>r\<cdot>(x, y)"
by (simp_all add: branch_def spair_pat_def)

lemma sinl_pat [simp]:
  "branch (sinl_pat p)\<cdot>r\<cdot>\<bottom> = \<bottom>"
  "x \<noteq> \<bottom> \<Longrightarrow> branch (sinl_pat p)\<cdot>r\<cdot>(sinl\<cdot>x) = branch p\<cdot>r\<cdot>x"
  "y \<noteq> \<bottom> \<Longrightarrow> branch (sinl_pat p)\<cdot>r\<cdot>(sinr\<cdot>y) = fail"
by (simp_all add: branch_def sinl_pat_def)

lemma sinr_pat [simp]:
  "branch (sinr_pat p)\<cdot>r\<cdot>\<bottom> = \<bottom>"
  "x \<noteq> \<bottom> \<Longrightarrow> branch (sinr_pat p)\<cdot>r\<cdot>(sinl\<cdot>x) = fail"
  "y \<noteq> \<bottom> \<Longrightarrow> branch (sinr_pat p)\<cdot>r\<cdot>(sinr\<cdot>y) = branch p\<cdot>r\<cdot>y"
by (simp_all add: branch_def sinr_pat_def)

(* up needs no side condition. *)
lemma up_pat [simp]:
  "branch (up_pat p)\<cdot>r\<cdot>\<bottom> = \<bottom>"
  "branch (up_pat p)\<cdot>r\<cdot>(up\<cdot>x) = branch p\<cdot>r\<cdot>x"
by (simp_all add: branch_def up_pat_def)

(* Constant patterns bind nothing, so the right-hand side is wrapped in
   unit_when. *)
lemma TT_pat [simp]:
  "branch TT_pat\<cdot>(unit_when\<cdot>r)\<cdot>\<bottom> = \<bottom>"
  "branch TT_pat\<cdot>(unit_when\<cdot>r)\<cdot>TT = succeed\<cdot>r"
  "branch TT_pat\<cdot>(unit_when\<cdot>r)\<cdot>FF = fail"
by (simp_all add: branch_def TT_pat_def)

lemma FF_pat [simp]:
  "branch FF_pat\<cdot>(unit_when\<cdot>r)\<cdot>\<bottom> = \<bottom>"
  "branch FF_pat\<cdot>(unit_when\<cdot>r)\<cdot>TT = fail"
  "branch FF_pat\<cdot>(unit_when\<cdot>r)\<cdot>FF = succeed\<cdot>r"
by (simp_all add: branch_def FF_pat_def)

lemma ONE_pat [simp]:
  "branch ONE_pat\<cdot>(unit_when\<cdot>r)\<cdot>\<bottom> = \<bottom>"
  "branch ONE_pat\<cdot>(unit_when\<cdot>r)\<cdot>ONE = succeed\<cdot>r"
by (simp_all add: branch_def ONE_pat_def)
       
   304 
       
   305 
       
subsection {* Wildcards, as-patterns, and lazy patterns *}

(* Wildcard: succeeds with () on every argument, bottom included. *)
definition
  wild_pat :: "'a \<rightarrow> unit match" where
  "wild_pat = (\<Lambda> x. succeed\<cdot>())"

(* As-pattern: matches x with p and, on success with a, returns (x, a);
   fails when p fails. *)
definition
  as_pat :: "('a \<rightarrow> 'b match) \<Rightarrow> 'a \<rightarrow> ('a \<times> 'b) match" where
  "as_pat p = (\<Lambda> x. match_case\<cdot>fail\<cdot>(\<Lambda> a. succeed\<cdot>(x, a))\<cdot>(p\<cdot>x))"

(* Lazy pattern: always succeeds; its result is cases applied to p x, so it
   is bottom whenever p fails or yields bottom. *)
definition
  lazy_pat :: "('a \<rightarrow> 'b::pcpo match) \<Rightarrow> ('a \<rightarrow> 'b match)" where
  "lazy_pat p = (\<Lambda> x. succeed\<cdot>(cases\<cdot>(p\<cdot>x)))"

text {* Parse translations (patterns) *}
translations
  "_pat _" => "CONST wild_pat"

text {* Parse translations (variables) *}
(* The wildcard binds no variables. *)
translations
  "_variable _ r" => "_variable _noargs r"

text {* Print translations *}
translations
  "_" <= "_match (CONST wild_pat) _noargs"

lemma wild_pat [simp]: "branch wild_pat\<cdot>(unit_when\<cdot>r)\<cdot>x = succeed\<cdot>r"
by (simp add: branch_def wild_pat_def)

(* The as-pattern passes the whole value x as the first argument of the
   right-hand side, then continues with p. *)
lemma as_pat [simp]:
  "branch (as_pat p)\<cdot>(csplit\<cdot>r)\<cdot>x = branch p\<cdot>(r\<cdot>x)\<cdot>x"
apply (simp add: branch_def as_pat_def)
apply (cases "p\<cdot>x", simp_all)
done

(* A lazy pattern never fails: when p diverges or fails, the right-hand
   side is applied to bottom. *)
lemma lazy_pat [simp]:
  "branch p\<cdot>r\<cdot>x = \<bottom> \<Longrightarrow> branch (lazy_pat p)\<cdot>r\<cdot>x = succeed\<cdot>(r\<cdot>\<bottom>)"
  "branch p\<cdot>r\<cdot>x = fail \<Longrightarrow> branch (lazy_pat p)\<cdot>r\<cdot>x = succeed\<cdot>(r\<cdot>\<bottom>)"
  "branch p\<cdot>r\<cdot>x = succeed\<cdot>s \<Longrightarrow> branch (lazy_pat p)\<cdot>r\<cdot>x = succeed\<cdot>s"
apply (simp_all add: branch_def lazy_pat_def)
apply (cases "p\<cdot>x", simp_all)+
done
       
   348 
       
subsection {* Examples *}

(* Each term command parses, type-checks and prints its Case expression,
   exercising both the parse and the print translations above. *)

term "Case t of (:up\<cdot>(sinl\<cdot>x), sinr\<cdot>y:) \<Rightarrow> (x, y)"

term "\<Lambda> t. Case t of up\<cdot>(sinl\<cdot>a) \<Rightarrow> a | up\<cdot>(sinr\<cdot>b) \<Rightarrow> b"

term "\<Lambda> t. Case t of (:up\<cdot>(sinl\<cdot>_), sinr\<cdot>x:) \<Rightarrow> x"
       
   356 
       
   357 subsection {* ML code for generating definitions *}
       
   358 
       
   359 ML {*
       
   360 local open HOLCF_Library in
       
   361 
       
   362 val beta_rules =
       
   363   @{thms beta_cfun cont_id cont_const cont2cont_Rep_CFun cont2cont_LAM'} @
       
   364   @{thms cont2cont_fst cont2cont_snd cont2cont_Pair};
       
   365 
       
   366 val beta_ss = HOL_basic_ss addsimps (simp_thms @ beta_rules);
       
   367 
       
   368 fun define_consts
       
   369     (specs : (binding * term * mixfix) list)
       
   370     (thy : theory)
       
   371     : (term list * thm list) * theory =
       
   372   let
       
   373     fun mk_decl (b, t, mx) = (b, fastype_of t, mx);
       
   374     val decls = map mk_decl specs;
       
   375     val thy = Cont_Consts.add_consts decls thy;
       
   376     fun mk_const (b, T, mx) = Const (Sign.full_name thy b, T);
       
   377     val consts = map mk_const decls;
       
   378     fun mk_def c (b, t, mx) =
       
   379       (Binding.suffix_name "_def" b, Logic.mk_equals (c, t));
       
   380     val defs = map2 mk_def consts specs;
       
   381     val (def_thms, thy) =
       
   382       PureThy.add_defs false (map Thm.no_attributes defs) thy;
       
   383   in
       
   384     ((consts, def_thms), thy)
       
   385   end;
       
   386 
       
   387 fun prove
       
   388     (thy : theory)
       
   389     (defs : thm list)
       
   390     (goal : term)
       
   391     (tacs : {prems: thm list, context: Proof.context} -> tactic list)
       
   392     : thm =
       
   393   let
       
   394     fun tac {prems, context} =
       
   395       rewrite_goals_tac defs THEN
       
   396       EVERY (tacs {prems = map (rewrite_rule defs) prems, context = context})
       
   397   in
       
   398     Goal.prove_global thy [] [] goal tac
       
   399   end;
       
   400 
       
   401 fun get_vars_avoiding
       
   402     (taken : string list)
       
   403     (args : (bool * typ) list)
       
   404     : (term list * term list) =
       
   405   let
       
   406     val Ts = map snd args;
       
   407     val ns = Name.variant_list taken (Datatype_Prop.make_tnames Ts);
       
   408     val vs = map Free (ns ~~ Ts);
       
   409     val nonlazy = map snd (filter_out (fst o fst) (args ~~ vs));
       
   410   in
       
   411     (vs, nonlazy)
       
   412   end;
       
   413 
       
(******************************************************************************)
(************** definitions and theorems for pattern combinators **************)
(******************************************************************************)

(* add_pattern_combinators bindings spec lhsT exhaust case_const case_rews thy
   For every constructor of the datatype lhsT, this function:
   - defines a pattern-combinator constant named <binding>_pat;
   - installs parse and print translation rules for Case patterns;
   - proves strictness and reduction rules for the new constants.
   Parameters:
   - bindings: one binding per constructor, in the same order as spec.
   - spec: each constructor constant with its (lazy flag, type) arguments.
   - lhsT: the datatype; its type arguments must be type variables
     (dest_TFree is applied to them).
   - exhaust: not used in this function body
     (NOTE(review): kept for interface compatibility?).
   - case_const: builds the datatype's case combinator at a given result type.
   - case_rews: rewrite rules for that case combinator, used in the proofs.
   Returns the strictness theorems followed by the reduction theorems,
   together with the extended theory. *)
fun add_pattern_combinators
    (bindings : binding list)
    (spec : (term * (bool * typ) list) list)
    (lhsT : typ)
    (exhaust : thm)
    (case_const : typ -> term)
    (case_rews : thm list)
    (thy : theory) =
  let

    (* utility functions *)

    (* mk_pair_pat (p1, p2): cpair_pat applied to p1 and p2, with the
       constant's type instantiated from the argument and result types of
       the two match functions. *)
    fun mk_pair_pat (p1, p2) =
      let
        val T1 = fastype_of p1;
        val T2 = fastype_of p2;
        val (U1, V1) = apsnd dest_matchT (dest_cfunT T1);
        val (U2, V2) = apsnd dest_matchT (dest_cfunT T2);
        val pat_typ = [T1, T2] --->
            (mk_prodT (U1, U2) ->> mk_matchT (mk_prodT (V1, V2)));
        val pat_const = Const (@{const_name cpair_pat}, pat_typ);
      in
        pat_const $ p1 $ p2
      end;
    (* mk_tuple_pat: an empty list gives succeed at type unit; otherwise
       the patterns are nested to the right with mk_pair_pat. *)
    fun mk_tuple_pat [] = succeed_const HOLogic.unitT
      | mk_tuple_pat ps = foldr1 mk_pair_pat ps;
    (* The branch constant instantiated at argument type T, pattern result
       type U and right-hand-side result type V. *)
    fun branch_const (T,U,V) = 
      Const (@{const_name branch},
        (T ->> mk_matchT U) --> (U ->> V) ->> T ->> mk_matchT V);

    (* define pattern combinators *)
    local
      val tns = map (fst o dest_TFree) (snd (dest_Type lhsT));

      (* pat_eqn (i, (bind, (con, args))): definition of the pattern
         combinator for the i-th constructor.  It takes one sub-pattern per
         constructor argument (named from pat) and is the case combinator
         in which branch i matches the argument tuple against the tuple of
         sub-patterns, and every other branch fails. *)
      fun pat_eqn (i, (bind, (con, args))) : binding * term * mixfix =
        let
          val pat_bind = Binding.suffix_name "_pat" bind;
          val Ts = map snd args;
          (* one fresh pcpo type variable per argument, named from 't and
             kept distinct from the datatype's own type variables *)
          val Vs =
              (map (K "'t") args)
              |> Datatype_Prop.indexify_names
              |> Name.variant_list tns
              |> map (fn t => TFree (t, @{sort pcpo}));
          val patNs = Datatype_Prop.indexify_names (map (K "pat") args);
          val patTs = map2 (fn T => fn V => T ->> mk_matchT V) Ts Vs;
          val pats = map Free (patNs ~~ patTs);
          val fail = mk_fail (mk_tupleT Vs);
          (* NOTE(review): nonlazy is not used here *)
          val (vs, nonlazy) = get_vars_avoiding patNs args;
          val rhs = big_lambdas vs (mk_tuple_pat pats ` mk_tuple vs);
          (* case branch for constructor j: rhs when j = i, otherwise a
             function of that constructor's arguments returning fail *)
          fun one_fun (j, (_, args')) =
            let
              val (vs', nonlazy) = get_vars_avoiding patNs args';
            in if i = j then rhs else big_lambdas vs' fail end;
          val funs = map_index one_fun spec;
          val body = list_ccomb (case_const (mk_matchT (mk_tupleT Vs)), funs);
        in
          (pat_bind, lambdas pats body, NoSyn)
        end;
    in
      val ((pat_consts, pat_defs), thy) =
          define_consts (map_index pat_eqn (bindings ~~ spec)) thy
    end;

    (* syntax translations for pattern combinators *)
    (* For each constructor this generates the same three kinds of rules as
       the hand-written translations earlier in this theory: a parse rule for
       _pat, a parse rule for _variable, and a print rule for _match. *)
    local
      open Syntax
      fun syntax c = Syntax.mark_const (fst (dest_Const c));
      fun app s (l, r) = Syntax.mk_appl (Constant s) [l, r];
      val capp = app @{const_syntax Rep_CFun};
      (* capps (f, [x1, ..., xn]): left-nested continuous application of f
         to x1 ... xn, as an AST *)
      val capps = Library.foldl capp

      fun app_var x = Syntax.mk_appl (Constant "_variable") [x, Variable "rhs"];
      fun app_pat x = Syntax.mk_appl (Constant "_pat") [x];
      (* no variables give _noargs; otherwise nested to the right with _args *)
      fun args_list [] = Constant "_noargs"
        | args_list xs = foldr1 (app "_args") xs;
      fun one_case_trans (pat, (con, args)) =
        let
          val cname = Constant (syntax con);
          val pname = Constant (syntax pat);
          val ns = 1 upto length args;
          val xs = map (fn n => Variable ("x"^(string_of_int n))) ns;
          val ps = map (fn n => Variable ("p"^(string_of_int n))) ns;
          val vs = map (fn n => Variable ("v"^(string_of_int n))) ns;
        in
          (* 1. _pat of the constructor applied to x1..xn becomes the pattern
                combinator applied to _pat x1 .. _pat xn;
             2. _variable of the constructor pattern becomes _variable of the
                argument list;
             3. printing: _match of the combinator applied to p1..pn, with
                the variables v1..vn, is shown as the constructor applied to
                _match p1 v1 .. _match pn vn. *)
          [ParseRule (app_pat (capps (cname, xs)),
                      mk_appl pname (map app_pat xs)),
           ParseRule (app_var (capps (cname, xs)),
                      app_var (args_list xs)),
           PrintRule (capps (cname, ListPair.map (app "_match") (ps,vs)),
                      app "_match" (mk_appl pname ps, args_list vs))]
        end;
      val trans_rules : Syntax.ast Syntax.trrule list =
          maps one_case_trans (pat_consts ~~ spec);
    in
      val thy = Sign.add_trrules_i trans_rules thy;
    end;

    (* prove strictness and reduction rules of pattern combinators *)
    (* Here thy is the theory that already contains the constants and
       syntax rules added above. *)
    local
      val tns = map (fst o dest_TFree) (snd (dest_Type lhsT));
      (* result type of the right-hand side, fresh with respect to tns *)
      val rn = Name.variant tns "'r";
      val R = TFree (rn, @{sort pcpo});
      (* pat_lhs (pat, args) returns (fun1, fun2, taken), where
         - fun1 is branch applied to the combinator pat (itself applied to
           fresh sub-patterns) and then to the right-hand side rhs, on the
           datatype;
         - fun2 is the same branch built from the tuple of sub-patterns,
           on the tuple of arguments;
         - taken lists the free-variable names already in use. *)
      fun pat_lhs (pat, args) =
        let
          val Ts = map snd args;
          val Vs =
              (map (K "'t") args)
              |> Datatype_Prop.indexify_names
              |> Name.variant_list (rn::tns)
              |> map (fn t => TFree (t, @{sort pcpo}));
          val patNs = Datatype_Prop.indexify_names (map (K "pat") args);
          val patTs = map2 (fn T => fn V => T ->> mk_matchT V) Ts Vs;
          val pats = map Free (patNs ~~ patTs);
          val k = Free ("rhs", mk_tupleT Vs ->> R);
          val branch1 = branch_const (lhsT, mk_tupleT Vs, R);
          val fun1 = (branch1 $ list_comb (pat, pats)) ` k;
          val branch2 = branch_const (mk_tupleT Ts, mk_tupleT Vs, R);
          val fun2 = (branch2 $ mk_tuple_pat pats) ` k;
          val taken = "rhs" :: patNs;
        in (fun1, fun2, taken) end;
      (* Strictness theorem for one combinator, with the goal built by
         mk_strict from fun1; proved by unfolding branch_def and the
         combinator definitions and then simplifying. *)
      fun pat_strict (pat, (con, args)) =
        let
          val (fun1, fun2, taken) = pat_lhs (pat, args);
          val defs = @{thm branch_def} :: pat_defs;
          val goal = mk_trp (mk_strict fun1);
          val rules = @{thms match_case_simps} @ case_rews;
          val tacs = [simp_tac (beta_ss addsimps rules) 1];
        in prove thy defs goal (K tacs) end;
      (* Reduction theorems for the i-th combinator, one per constructor j:
         fun1 applied to constructor j on fresh variables equals fun2
         applied to the variable tuple when i = j, and fail otherwise.
         The assumptions come from mk_defined for each non-lazy argument. *)
      fun pat_apps (i, (pat, (con, args))) =
        let
          val (fun1, fun2, taken) = pat_lhs (pat, args);
          fun pat_app (j, (con', args')) =
            let
              val (vs, nonlazy) = get_vars_avoiding taken args';
              val con_app = list_ccomb (con', vs);
              val assms = map (mk_trp o mk_defined) nonlazy;
              val rhs = if i = j then fun2 ` mk_tuple vs else mk_fail R;
              val concl = mk_trp (mk_eq (fun1 ` con_app, rhs));
              val goal = Logic.list_implies (assms, concl);
              val defs = @{thm branch_def} :: pat_defs;
              val rules = @{thms match_case_simps} @ case_rews;
              val tacs = [asm_simp_tac (beta_ss addsimps rules) 1];
            in prove thy defs goal (K tacs) end;
        in map_index pat_app spec end;
    in
      val pat_stricts = map pat_strict (pat_consts ~~ spec);
      val pat_apps = flat (map_index pat_apps (pat_consts ~~ spec));
    end;

  in
    (pat_stricts @ pat_apps, thy)
  end
       
   569 
       
   570 end
       
   571 *}
       
   572 
       
   573 (*
       
   574 Cut from HOLCF/Tools/domain_constructors.ML
       
   575 in function add_domain_constructors:
       
   576 
       
   577     ( * define and prove theorems for pattern combinators * )
       
   578     val (pat_thms : thm list, thy : theory) =
       
   579       let
       
   580         val bindings = map #1 spec;
       
   581         fun prep_arg (lazy, sel, T) = (lazy, T);
       
   582         fun prep_con c (b, args, mx) = (c, map prep_arg args);
       
   583         val pat_spec = map2 prep_con con_consts spec;
       
   584       in
       
   585         add_pattern_combinators bindings pat_spec lhsT
       
   586           exhaust case_const cases thy
       
   587       end
       
   588 
       
   589 *)
       
   590 
       
   591 end