generate pattern combinators for new datatypes
authorhuffman
Tue, 08 Nov 2005 02:19:11 +0100
changeset 18113 fb76eea85835
parent 18112 dc1d6f588204
child 18114 a36a9b2921e9
generate pattern combinators for new datatypes
src/HOLCF/domain/axioms.ML
src/HOLCF/domain/library.ML
src/HOLCF/domain/syntax.ML
src/HOLCF/domain/theorems.ML
--- a/src/HOLCF/domain/axioms.ML	Mon Nov 07 23:33:01 2005 +0100
+++ b/src/HOLCF/domain/axioms.ML	Tue Nov 08 02:19:11 2005 +0100
@@ -72,6 +72,21 @@
                                else %%:failN) args)) cons))
 	in map mdef cons end;
 
+  val pat_defs =
+    let
+      fun pdef (con,args) =
+        let
+          val ps = mapn (fn n => fn _ => %:("pat" ^ string_of_int n)) 1 args;
+          val xs = map (bound_arg args) args;
+          val r = Bound (length args);
+          val rhs = case args of [] => %%:returnN ` r
+                                | _ => foldr1 cpair_pat ps ` r ` mk_ctuple xs;
+          fun one_con (con',args') = foldr /\# (if con'=con then rhs else %%:failN) args';
+        in (pat_name con ^"_def", list_comb (%%:(pat_name con), ps) == 
+               /\ "r" (list_ccomb(%%:(dname^"_when"), map one_con cons)))
+        end
+    in map pdef cons end;
+  
   val sel_defs = let
 	fun sdef con n arg = Option.map (fn sel => (sel^"_def",%%:sel == 
 		 list_ccomb(%%:(dname^"_when"),map 
@@ -92,7 +107,7 @@
 in (dnam,
     [abs_iso_ax, rep_iso_ax, reach_ax],
     [when_def, copy_def] @
-     con_defs @ dis_defs @ mat_defs @ sel_defs @
+     con_defs @ dis_defs @ mat_defs @ pat_defs @ sel_defs @
     [take_def, finite_def])
 end; (* let *)
 
--- a/src/HOLCF/domain/library.ML	Mon Nov 07 23:33:01 2005 +0100
+++ b/src/HOLCF/domain/library.ML	Tue Nov 08 02:19:11 2005 +0100
@@ -55,6 +55,8 @@
 fun dis_name_ con = "is_"^ (strip_esc   con);
 fun mat_name  con = "match_"^ (extern_name con);
 fun mat_name_ con = "match_"^ (strip_esc   con);
+fun pat_name  con = (extern_name con) ^ "_pat";
+fun pat_name_ con = (strip_esc   con) ^ "_pat";
 
 (* make distinct names out of the type list, 
    forbidding "o","n..","x..","f..","P.." as names *)
@@ -127,6 +129,7 @@
 val fixN       = "Fix.fix";
 val returnN    = "Fixrec.return";
 val failN      = "Fixrec.fail";
+val cpair_patN = "Fixrec.cpair_pat";
 
 val pcpoN      = "Pcpo.pcpo"
 val pcpoS      = [pcpoN];
@@ -203,6 +206,7 @@
 |   mk_stuple ts = foldr1 spair ts;
 fun mk_ctupleT [] = HOLogic.unitT   (* used in match_defs *)
 |   mk_ctupleT Ts = foldr1 HOLogic.mk_prodT Ts;
+fun cpair_pat (p1,p2) = %%:cpair_patN $ p1 $ p2;
 fun lift_defined f = lift (fn x => defined (f x));
 fun bound_arg vns v = Bound(length vns -find_index_eq v vns -1);
 
--- a/src/HOLCF/domain/syntax.ML	Mon Nov 07 23:33:01 2005 +0100
+++ b/src/HOLCF/domain/syntax.ML	Tue Nov 08 02:19:11 2005 +0100
@@ -51,10 +51,19 @@
 			   Mixfix(escape ("match_" ^ con), [], Syntax.max_pri));
   fun sel1 (_,sel,typ)  = Option.map (fn s => (s,dtype ->> typ,NoSyn)) sel;
   fun sel (_   ,_,args) = List.mapPartial sel1 args;
+  fun freetvar s n      = let val tvar = mk_TFree (s ^ string_of_int n) in
+			  if tvar mem typevars then freetvar ("t"^s) n else tvar end;
+  fun mk_patT (a,b,c)   = b ->> a ->> mk_ssumT(oneT, mk_uT c);
+  fun pat_arg_typ n arg = mk_patT (third arg, freetvar "t" n, freetvar "t" (n+1));
+  fun pat (con ,s,args) = (pat_name_ con, (mapn pat_arg_typ 1 args) --->
+			   mk_patT (dtype, freetvar "t" 1, freetvar "t" (length args + 1)),
+			   Mixfix(escape (con ^ "_pat"), [], Syntax.max_pri));
+
 in
   val consts_con = map con cons';
   val consts_dis = map dis cons';
   val consts_mat = map mat cons';
+  val consts_pat = map pat cons';
   val consts_sel = List.concat(map sel cons');
 end;
 
@@ -78,6 +87,9 @@
     fun case1 n (con,mx,args) = app "_case1" (con1 n (con,mx,args), expvar n);
     fun arg1 n (con,_,args) = foldr cabs (expvar n) (mapn (argvar n) 1 args);
     fun when1 n m = if n = m then arg1 n else K (Constant "UU");
+
+    fun app_var x = mk_appl (Constant "_var") [x];
+    fun app_pat x = mk_appl (Constant "_pat") [x];
   in
     val case_trans = ParsePrintRule
         (app "_case_syntax" (Variable "x", foldr1 (app "_case2") (mapn case1 1 cons')),
@@ -86,13 +98,21 @@
     val abscon_trans = mapn (fn n => fn (con,mx,args) => ParsePrintRule
         (cabs (con1 n (con,mx,args), expvar n),
          Library.foldl capp (Constant (dnam^"_when"), mapn (when1 n) 1 cons'))) 1 cons';
+    
+    val pattern_trans = mapn (fn n => fn (con,mx,args) => ParseRule
+          (app_var (Library.foldl capp (c_ast con mx, mapn (argvar n) 1 args)),
+           mk_appl (Constant (pat_name_ con)) (map app_var (mapn (argvar n) 1 args)))) 1 cons';
+    
+    val pattern_trans' = mapn (fn n => fn (con,mx,args) => PrintRule
+          (Library.foldl capp (c_ast con mx, map app_pat (mapn (argvar n) 1 args)),
+           app_pat (mk_appl (Constant (pat_name_ con)) (mapn (argvar n) 1 args)))) 1 cons';
   end;
 end;
 
 in ([const_rep, const_abs, const_when, const_copy] @ 
-     consts_con @ consts_dis @ consts_mat @ consts_sel @
+     consts_con @ consts_dis @ consts_mat @ consts_pat @ consts_sel @
     [const_take, const_finite],
-    (case_trans::abscon_trans))
+    (case_trans::(abscon_trans @ pattern_trans @ pattern_trans')))
 end; (* let *)
 
 (* ----- putting all the syntax stuff together ------------------------------ *)
--- a/src/HOLCF/domain/theorems.ML	Mon Nov 07 23:33:01 2005 +0100
+++ b/src/HOLCF/domain/theorems.ML	Tue Nov 08 02:19:11 2005 +0100
@@ -74,6 +74,7 @@
 val axs_con_def   = map (fn (con,_) => ga (extern_name con^"_def") dname) cons;
 val axs_dis_def   = map (fn (con,_) => ga (   dis_name con^"_def") dname) cons;
 val axs_mat_def   = map (fn (con,_) => ga (   mat_name con^"_def") dname) cons;
+val axs_pat_def   = map (fn (con,_) => ga (   pat_name con^"_def") dname) cons;
 val axs_sel_def   = List.concat(map (fn (_,args) => List.mapPartial (fn arg =>
                  Option.map (fn sel => ga (sel^"_def") dname) (sel_of arg)) args)
 									  cons);
@@ -206,6 +207,22 @@
         in List.concat(map (fn (c,_) => map (one_mat c) cons) cons) end;
 in mat_stricts @ mat_apps end;
 
+val pat_rews = let
+  fun ps args = mapn (fn n => fn _ => %:("pat" ^ string_of_int n)) 1 args;
+  fun pat_lhs (con,args) = list_comb (%%:(pat_name con), ps args);
+  fun pat_rhs (con,[]) = %%:returnN ` (%:"rhs")
+  |   pat_rhs (con,args) = (foldr1 cpair_pat (ps args))`(%:"rhs")`(mk_ctuple (map %# args));
+  val pat_stricts = map (fn (con,args) => pg axs_pat_def (mk_trp(
+                             strict(pat_lhs (con,args)`(%:"rhs"))))
+                   [simp_tac (HOLCF_ss addsimps [when_strict]) 1]) cons;
+  val pat_apps = let fun one_pat c (con,args)= pg axs_pat_def
+                   (lift_defined %: (nonlazy args,
+                        (mk_trp((pat_lhs c)`(%:"rhs")`(con_app con args) ===
+                              (if con = fst c then pat_rhs c else %%:failN)))))
+                   [asm_simp_tac (HOLCF_ss addsimps when_rews) 1];
+        in List.concat(map (fn c => map (one_pat c) cons) cons) end;
+in pat_stricts @ pat_apps end;
+
 val con_stricts = List.concat(map (fn (con,args) => map (fn vn =>
                         pg con_appls
                            (mk_trp(con_app2 con (fn arg => if vname arg = vn 
@@ -344,6 +361,7 @@
 		("con_rews", con_rews),
 		("sel_rews", sel_rews),
 		("dis_rews", dis_rews),
+		("pat_rews", pat_rews),
 		("dist_les", dist_les),
 		("dist_eqs", dist_eqs),
 		("inverts" , inverts ),