change definition of match combinators for fixrec package
authorhuffman
Sat, 11 Apr 2009 08:44:41 -0700
changeset 30912 4022298c1a86
parent 30911 7809cbaa1b61
child 30913 10b26965a08f
change definition of match combinators for fixrec package
src/HOLCF/Fixrec.thy
src/HOLCF/Tools/domain/domain_axioms.ML
src/HOLCF/Tools/domain/domain_syntax.ML
src/HOLCF/Tools/domain/domain_theorems.ML
src/HOLCF/Tools/fixrec_package.ML
--- a/src/HOLCF/Fixrec.thy	Fri Apr 10 17:39:53 2009 -0700
+++ b/src/HOLCF/Fixrec.thy	Sat Apr 11 08:44:41 2009 -0700
@@ -475,86 +475,95 @@
 defaultsort pcpo
 
 definition
-  match_UU :: "'a \<rightarrow> unit maybe" where
-  "match_UU = (\<Lambda> x. fail)"
+  match_UU :: "'a \<rightarrow> 'c maybe \<rightarrow> 'c maybe"
+where
+  "match_UU = strictify\<cdot>(\<Lambda> x k. fail)"
 
 definition
-  match_cpair :: "'a::cpo \<times> 'b::cpo \<rightarrow> ('a \<times> 'b) maybe" where
-  "match_cpair = csplit\<cdot>(\<Lambda> x y. return\<cdot><x,y>)"
+  match_cpair :: "'a::cpo \<times> 'b::cpo \<rightarrow> ('a \<rightarrow> 'b \<rightarrow> 'c maybe) \<rightarrow> 'c maybe"
+where
+  "match_cpair = (\<Lambda> x k. csplit\<cdot>k\<cdot>x)"
 
 definition
-  match_spair :: "'a \<otimes> 'b \<rightarrow> ('a \<times> 'b) maybe" where
-  "match_spair = ssplit\<cdot>(\<Lambda> x y. return\<cdot><x,y>)"
+  match_spair :: "'a \<otimes> 'b \<rightarrow> ('a \<rightarrow> 'b \<rightarrow> 'c maybe) \<rightarrow> 'c maybe"
+where
+  "match_spair = (\<Lambda> x k. ssplit\<cdot>k\<cdot>x)"
 
 definition
-  match_sinl :: "'a \<oplus> 'b \<rightarrow> 'a maybe" where
-  "match_sinl = sscase\<cdot>return\<cdot>(\<Lambda> y. fail)"
+  match_sinl :: "'a \<oplus> 'b \<rightarrow> ('a \<rightarrow> 'c maybe) \<rightarrow> 'c maybe"
+where
+  "match_sinl = (\<Lambda> x k. sscase\<cdot>k\<cdot>(\<Lambda> b. fail)\<cdot>x)"
 
 definition
-  match_sinr :: "'a \<oplus> 'b \<rightarrow> 'b maybe" where
-  "match_sinr = sscase\<cdot>(\<Lambda> x. fail)\<cdot>return"
+  match_sinr :: "'a \<oplus> 'b \<rightarrow> ('b \<rightarrow> 'c maybe) \<rightarrow> 'c maybe"
+where
+  "match_sinr = (\<Lambda> x k. sscase\<cdot>(\<Lambda> a. fail)\<cdot>k\<cdot>x)"
 
 definition
-  match_up :: "'a::cpo u \<rightarrow> 'a maybe" where
-  "match_up = fup\<cdot>return"
+  match_up :: "'a::cpo u \<rightarrow> ('a \<rightarrow> 'c maybe) \<rightarrow> 'c maybe"
+where
+  "match_up = (\<Lambda> x k. fup\<cdot>k\<cdot>x)"
 
 definition
-  match_ONE :: "one \<rightarrow> unit maybe" where
-  "match_ONE = (\<Lambda> ONE. return\<cdot>())"
+  match_ONE :: "one \<rightarrow> 'c maybe \<rightarrow> 'c maybe"
+where
+  "match_ONE = (\<Lambda> ONE k. k)"
+
+definition
+  match_TT :: "tr \<rightarrow> 'c maybe \<rightarrow> 'c maybe"
+where
+  "match_TT = (\<Lambda> x k. If x then k else fail fi)"
  
 definition
-  match_TT :: "tr \<rightarrow> unit maybe" where
-  "match_TT = (\<Lambda> b. If b then return\<cdot>() else fail fi)"
- 
-definition
-  match_FF :: "tr \<rightarrow> unit maybe" where
-  "match_FF = (\<Lambda> b. If b then fail else return\<cdot>() fi)"
+  match_FF :: "tr \<rightarrow> 'c maybe \<rightarrow> 'c maybe"
+where
+  "match_FF = (\<Lambda> x k. If x then fail else k fi)"
 
 lemma match_UU_simps [simp]:
-  "match_UU\<cdot>x = fail"
+  "match_UU\<cdot>x\<cdot>k = (if x = \<bottom> then \<bottom> else fail)"
 by (simp add: match_UU_def)
 
 lemma match_cpair_simps [simp]:
-  "match_cpair\<cdot><x,y> = return\<cdot><x,y>"
+  "match_cpair\<cdot>\<langle>x, y\<rangle>\<cdot>k = k\<cdot>x\<cdot>y"
 by (simp add: match_cpair_def)
 
 lemma match_spair_simps [simp]:
-  "\<lbrakk>x \<noteq> \<bottom>; y \<noteq> \<bottom>\<rbrakk> \<Longrightarrow> match_spair\<cdot>(:x,y:) = return\<cdot><x,y>"
-  "match_spair\<cdot>\<bottom> = \<bottom>"
+  "\<lbrakk>x \<noteq> \<bottom>; y \<noteq> \<bottom>\<rbrakk> \<Longrightarrow> match_spair\<cdot>(:x, y:)\<cdot>k = k\<cdot>x\<cdot>y"
+  "match_spair\<cdot>\<bottom>\<cdot>k = \<bottom>"
 by (simp_all add: match_spair_def)
 
 lemma match_sinl_simps [simp]:
-  "x \<noteq> \<bottom> \<Longrightarrow> match_sinl\<cdot>(sinl\<cdot>x) = return\<cdot>x"
-  "x \<noteq> \<bottom> \<Longrightarrow> match_sinl\<cdot>(sinr\<cdot>x) = fail"
-  "match_sinl\<cdot>\<bottom> = \<bottom>"
+  "x \<noteq> \<bottom> \<Longrightarrow> match_sinl\<cdot>(sinl\<cdot>x)\<cdot>k = k\<cdot>x"
+  "x \<noteq> \<bottom> \<Longrightarrow> match_sinl\<cdot>(sinr\<cdot>x)\<cdot>k = fail"
+  "match_sinl\<cdot>\<bottom>\<cdot>k = \<bottom>"
 by (simp_all add: match_sinl_def)
 
 lemma match_sinr_simps [simp]:
-  "x \<noteq> \<bottom> \<Longrightarrow> match_sinr\<cdot>(sinr\<cdot>x) = return\<cdot>x"
-  "x \<noteq> \<bottom> \<Longrightarrow> match_sinr\<cdot>(sinl\<cdot>x) = fail"
-  "match_sinr\<cdot>\<bottom> = \<bottom>"
+  "x \<noteq> \<bottom> \<Longrightarrow> match_sinr\<cdot>(sinr\<cdot>x)\<cdot>k = k\<cdot>x"
+  "x \<noteq> \<bottom> \<Longrightarrow> match_sinr\<cdot>(sinl\<cdot>x)\<cdot>k = fail"
+  "match_sinr\<cdot>\<bottom>\<cdot>k = \<bottom>"
 by (simp_all add: match_sinr_def)
 
 lemma match_up_simps [simp]:
-  "match_up\<cdot>(up\<cdot>x) = return\<cdot>x"
-  "match_up\<cdot>\<bottom> = \<bottom>"
+  "match_up\<cdot>(up\<cdot>x)\<cdot>k = k\<cdot>x"
+  "match_up\<cdot>\<bottom>\<cdot>k = \<bottom>"
 by (simp_all add: match_up_def)
 
 lemma match_ONE_simps [simp]:
-  "match_ONE\<cdot>ONE = return\<cdot>()"
-  "match_ONE\<cdot>\<bottom> = \<bottom>"
+  "match_ONE\<cdot>ONE\<cdot>k = k"
+  "match_ONE\<cdot>\<bottom>\<cdot>k = \<bottom>"
 by (simp_all add: match_ONE_def)
 
 lemma match_TT_simps [simp]:
-  "match_TT\<cdot>TT = return\<cdot>()"
-  "match_TT\<cdot>FF = fail"
-  "match_TT\<cdot>\<bottom> = \<bottom>"
+  "match_TT\<cdot>TT\<cdot>k = k"
+  "match_TT\<cdot>FF\<cdot>k = fail"
+  "match_TT\<cdot>\<bottom>\<cdot>k = \<bottom>"
 by (simp_all add: match_TT_def)
 
 lemma match_FF_simps [simp]:
-  "match_FF\<cdot>FF = return\<cdot>()"
-  "match_FF\<cdot>TT = fail"
-  "match_FF\<cdot>\<bottom> = \<bottom>"
+  "match_FF\<cdot>FF\<cdot>k = k"
+  "match_FF\<cdot>TT\<cdot>k = fail"
+  "match_FF\<cdot>\<bottom>\<cdot>k = \<bottom>"
 by (simp_all add: match_FF_def)
 
 subsection {* Mutual recursion *}
--- a/src/HOLCF/Tools/domain/domain_axioms.ML	Fri Apr 10 17:39:53 2009 -0700
+++ b/src/HOLCF/Tools/domain/domain_axioms.ML	Sat Apr 11 08:44:41 2009 -0700
@@ -60,14 +60,18 @@
 			   (if con'=con then TT else FF) args)) cons))
 	in map ddef cons end;
 
-  val mat_defs = let
-	fun mdef (con,_) = (mat_name con ^"_def",%%:(mat_name con) == 
-		 list_ccomb(%%:(dname^"_when"),map 
-			(fn (con',args) => (List.foldr /\#
-			   (if con'=con
-                               then mk_return (mk_ctuple (map (bound_arg args) args))
-                               else mk_fail) args)) cons))
-	in map mdef cons end;
+  val mat_defs =
+    let
+      fun mdef (con,_) =
+        let
+          val k = Bound 0
+          val x = Bound 1
+          fun one_con (con', args') =
+	    if con'=con then k else List.foldr /\# mk_fail args'
+          val w = list_ccomb(%%:(dname^"_when"), map one_con cons)
+          val rhs = /\ "x" (/\ "k" (w ` x))
+        in (mat_name con ^"_def", %%:(mat_name con) == rhs) end
+    in map mdef cons end;
 
   val pat_defs =
     let
--- a/src/HOLCF/Tools/domain/domain_syntax.ML	Fri Apr 10 17:39:53 2009 -0700
+++ b/src/HOLCF/Tools/domain/domain_syntax.ML	Sat Apr 11 08:44:41 2009 -0700
@@ -46,12 +46,14 @@
 			(* strictly speaking, these constants have one argument,
 			   but the mixfix (without arguments) is introduced only
 			   to generate parse rules for non-alphanumeric names*)
-  fun mat (con ,s,args) = (mat_name_ con, dtype->>mk_maybeT(mk_ctupleT(map third args)),
+  fun freetvar s n      = let val tvar = mk_TFree (s ^ string_of_int n) in
+			  if tvar mem typevars then freetvar ("t"^s) n else tvar end;
+  fun mk_matT (a,bs,c)  = a ->> foldr (op ->>) (mk_maybeT c) bs ->> mk_maybeT c;
+  fun mat (con,s,args)  = (mat_name_ con,
+                           mk_matT(dtype, map third args, freetvar "t" 1),
 			   Mixfix(escape ("match_" ^ con), [], Syntax.max_pri));
   fun sel1 (_,sel,typ)  = Option.map (fn s => (s,dtype ->> typ,NoSyn)) sel;
   fun sel (_   ,_,args) = List.mapPartial sel1 args;
-  fun freetvar s n      = let val tvar = mk_TFree (s ^ string_of_int n) in
-			  if tvar mem typevars then freetvar ("t"^s) n else tvar end;
   fun mk_patT (a,b)     = a ->> mk_maybeT b;
   fun pat_arg_typ n arg = mk_patT (third arg, freetvar "t" n);
   fun pat (con ,s,args) = (pat_name_ con, (mapn pat_arg_typ 1 args) --->
--- a/src/HOLCF/Tools/domain/domain_theorems.ML	Fri Apr 10 17:39:53 2009 -0700
+++ b/src/HOLCF/Tools/domain/domain_theorems.ML	Sat Apr 11 08:44:41 2009 -0700
@@ -314,8 +314,8 @@
 local
   fun mat_strict (con, _) =
     let
-      val goal = mk_trp (strict (%%:(mat_name con)));
-      val tacs = [rtac when_strict 1];
+      val goal = mk_trp (%%:(mat_name con) ` UU ` %:"rhs" === UU);
+      val tacs = [asm_simp_tac (HOLCF_ss addsimps [when_strict]) 1];
     in pg axs_mat_def goal (K tacs) end;
 
   val _ = trace " Proving mat_stricts...";
@@ -323,10 +323,10 @@
 
   fun one_mat c (con, args) =
     let
-      val lhs = %%:(mat_name c) ` con_app con args;
+      val lhs = %%:(mat_name c) ` con_app con args ` %:"rhs";
       val rhs =
         if con = c
-        then mk_return (mk_ctuple (map %# args))
+        then list_ccomb (%:"rhs", map %# args)
         else mk_fail;
       val goal = lift_defined %: (nonlazy args, mk_trp (lhs === rhs));
       val tacs = [asm_simp_tac (HOLCF_ss addsimps when_rews) 1];
--- a/src/HOLCF/Tools/fixrec_package.ML	Fri Apr 10 17:39:53 2009 -0700
+++ b/src/HOLCF/Tools/fixrec_package.ML	Sat Apr 11 08:44:41 2009 -0700
@@ -36,6 +36,8 @@
 
 infixr 6 ->>; val (op ->>) = cfunT;
 
+fun cfunsT (Ts, U) = foldr cfunT U Ts;
+
 fun dest_cfunT (Type(@{type_name "->"}, [T, U])) = (T, U)
   | dest_cfunT T = raise TYPE ("dest_cfunT", [T], []);
 
@@ -57,7 +59,9 @@
   | tupleT [T] = T
   | tupleT (T :: Ts) = HOLogic.mk_prodT (T, tupleT Ts);
 
-fun matchT T = body_cfun T ->> maybeT (tupleT (binder_cfun T));
+fun matchT (T, U) =
+  body_cfun T ->> cfunsT (binder_cfun T, U) ->> U;
+
 
 (*************************************************************************)
 (***************************** building terms ****************************)
@@ -240,10 +244,10 @@
         fun result_type (Type(@{type_name "->"},[_,T])) (x::xs) = result_type T xs
           | result_type T _ = T;
         val v = Free(n, result_type T vs);
-        val m = Const(match_name c, matchT T);
-        val k = lambda_ctuple vs rhs;
+        val m = Const(match_name c, matchT (T, fastype_of rhs));
+        val k = big_lambdas vs rhs;
       in
-        (mk_bind (m`v, k), v, n::taken)
+        (m`v`k, v, n::taken)
       end
   | Free(n,_) => fixrec_err ("expected constructor, found free variable " ^ quote n)
   | _ => fixrec_err "pre_build: invalid pattern";