use TheoryData to keep track of pattern match combinators
authorhuffman
Thu, 26 Feb 2009 10:28:53 -0800
changeset 30131 6be1be402ef0
parent 30130 e23770bc97c8
child 30132 243a05a67c41
use TheoryData to keep track of pattern match combinators
src/HOLCF/Fixrec.thy
src/HOLCF/Tools/domain/domain_axioms.ML
src/HOLCF/Tools/fixrec_package.ML
--- a/src/HOLCF/Fixrec.thy	Thu Feb 26 08:48:33 2009 -0800
+++ b/src/HOLCF/Fixrec.thy	Thu Feb 26 10:28:53 2009 -0800
@@ -583,6 +583,20 @@
 
 use "Tools/fixrec_package.ML"
 
+setup {* FixrecPackage.setup *}
+
+setup {*
+  FixrecPackage.add_matchers
+    [ (@{const_name up}, @{const_name match_up}),
+      (@{const_name sinl}, @{const_name match_sinl}),
+      (@{const_name sinr}, @{const_name match_sinr}),
+      (@{const_name spair}, @{const_name match_spair}),
+      (@{const_name cpair}, @{const_name match_cpair}),
+      (@{const_name ONE}, @{const_name match_ONE}),
+      (@{const_name TT}, @{const_name match_TT}),
+      (@{const_name FF}, @{const_name match_FF}) ]
+*}
+
 hide (open) const return bind fail run cases
 
 end
--- a/src/HOLCF/Tools/domain/domain_axioms.ML	Thu Feb 26 08:48:33 2009 -0800
+++ b/src/HOLCF/Tools/domain/domain_axioms.ML	Thu Feb 26 10:28:53 2009 -0800
@@ -39,7 +39,7 @@
     fun one_con (con,args) =
         foldr /\# (list_ccomb (%%:con, mapn (idxs (length args)) 1 args)) args;
   in ("copy_def", %%:(dname^"_copy") ==
-       /\"f" (list_ccomb (%%:(dname^"_when"), map one_con cons))) end;
+       /\ "f" (list_ccomb (%%:(dname^"_when"), map one_con cons))) end;
 
 (* -- definitions concerning the constructors, discriminators and selectors - *)
 
@@ -107,7 +107,7 @@
     [when_def, copy_def] @
      con_defs @ dis_defs @ mat_defs @ pat_defs @ sel_defs @
     [take_def, finite_def])
-end; (* let *)
+end; (* let (calc_axioms) *)
 
 fun infer_props thy = map (apsnd (FixrecPackage.legacy_infer_prop thy));
 
@@ -117,6 +117,14 @@
 fun add_defs_i x = snd o (PureThy.add_defs false) (map (Thm.no_attributes o apfst Binding.name) x);
 fun add_defs_infer defs thy = add_defs_i (infer_props thy defs) thy;
 
+fun add_matchers (((dname,_),cons) : eq) thy =
+  let
+    val con_names = map fst cons;
+    val mat_names = map mat_name con_names;
+    fun qualify n = Sign.full_name thy (Binding.name n);
+    val ms = map qualify con_names ~~ map qualify mat_names;
+  in FixrecPackage.add_matchers ms thy end;
+
 in (* local *)
 
 fun add_axioms (comp_dnam, eqs : eq list) thy' = let
@@ -125,7 +133,7 @@
   val x_name = idx_name dnames "x"; 
   fun copy_app dname = %%:(dname^"_copy")`Bound 0;
   val copy_def = ("copy_def" , %%:(comp_dname^"_copy") ==
-				    /\"f"(mk_ctuple (map copy_app dnames)));
+				    /\ "f"(mk_ctuple (map copy_app dnames)));
   val bisim_def = ("bisim_def",%%:(comp_dname^"_bisim")==mk_lam("R",
     let
       fun one_con (con,args) = let
@@ -164,7 +172,8 @@
 in thy |> Sign.add_path comp_dnam  
        |> add_defs_infer (bisim_def::(if length eqs>1 then [copy_def] else []))
        |> Sign.parent_path
-end;
+       |> fold add_matchers eqs
+end; (* let (add_axioms) *)
 
 end; (* local *)
 end; (* struct *)
--- a/src/HOLCF/Tools/fixrec_package.ML	Thu Feb 26 08:48:33 2009 -0800
+++ b/src/HOLCF/Tools/fixrec_package.ML	Thu Feb 26 10:28:53 2009 -0800
@@ -12,6 +12,8 @@
   val add_fixrec_i: bool -> ((binding * attribute list) * term) list list -> theory -> theory
   val add_fixpat: Attrib.binding * string list -> theory -> theory
   val add_fixpat_i: (binding * attribute list) * term list -> theory -> theory
+  val add_matchers: (string * string) list -> theory -> theory
+  val setup: theory -> theory
 end;
 
 structure FixrecPackage: FIXREC_PACKAGE =
@@ -123,6 +125,17 @@
 (*********** monadic notation and pattern matching compilation ***********)
 (*************************************************************************)
 
+structure FixrecMatchData = TheoryDataFun (
+  type T = string Symtab.table;
+  val empty = Symtab.empty;
+  val copy = I;
+  val extend = I;
+  fun merge _ tabs : T = Symtab.merge (K true) tabs;
+);
+
+(* associate match functions with pattern constants *)
+fun add_matchers ms = FixrecMatchData.map (fold Symtab.update ms);
+
 fun add_names (Const(a,_), bs) = insert (op =) (Sign.base_name a) bs
   | add_names (Free(a,_) , bs) = insert (op =) a bs
   | add_names (f $ u     , bs) = add_names (f, add_names(u, bs))
@@ -132,20 +145,20 @@
 fun add_terms ts xs = foldr add_names xs ts;
 
 (* builds a monadic term for matching a constructor pattern *)
-fun pre_build pat rhs vs taken =
+fun pre_build match_name pat rhs vs taken =
   case pat of
     Const(@{const_name Rep_CFun},_)$f$(v as Free(n,T)) =>
-      pre_build f rhs (v::vs) taken
+      pre_build match_name f rhs (v::vs) taken
   | Const(@{const_name Rep_CFun},_)$f$x =>
-      let val (rhs', v, taken') = pre_build x rhs [] taken;
-      in pre_build f rhs' (v::vs) taken' end
+      let val (rhs', v, taken') = pre_build match_name x rhs [] taken;
+      in pre_build match_name f rhs' (v::vs) taken' end
   | Const(c,T) =>
       let
         val n = Name.variant taken "v";
         fun result_type (Type(@{type_name "->"},[_,T])) (x::xs) = result_type T xs
           | result_type T _ = T;
         val v = Free(n, result_type T vs);
-        val m = "match_"^(extern_name(Sign.base_name c));
+        val m = match_name c;
         val k = lambda_ctuple vs rhs;
       in
         (%%:@{const_name Fixrec.bind}`(%%:m`v)`k, v, n::taken)
@@ -155,19 +168,22 @@
 
 (* builds a monadic term for matching a function definition pattern *)
 (* returns (name, arity, matcher) *)
-fun building pat rhs vs taken =
+fun building match_name pat rhs vs taken =
   case pat of
     Const(@{const_name Rep_CFun}, _)$f$(v as Free(n,T)) =>
-      building f rhs (v::vs) taken
+      building match_name f rhs (v::vs) taken
   | Const(@{const_name Rep_CFun}, _)$f$x =>
-      let val (rhs', v, taken') = pre_build x rhs [] taken;
-      in building f rhs' (v::vs) taken' end
+      let val (rhs', v, taken') = pre_build match_name x rhs [] taken;
+      in building match_name f rhs' (v::vs) taken' end
   | Const(name,_) => (name, length vs, big_lambdas vs rhs)
   | _ => fixrec_err "function is not declared as constant in theory";
 
-fun match_eq eq = 
+fun match_eq match_name eq = 
   let val (lhs,rhs) = dest_eqs eq;
-  in building lhs (%%:@{const_name Fixrec.return}`rhs) [] (add_terms [eq] []) end;
+  in
+    building match_name lhs (%%:@{const_name Fixrec.return}`rhs) []
+      (add_terms [eq] [])
+  end;
 
 (* returns the sum (using +++) of the terms in ms *)
 (* also applies "run" to the result! *)
@@ -190,9 +206,9 @@
       in (x::xs, y::ys, z::zs) end;
 
 (* this is the pattern-matching compiler function *)
-fun compile_pats eqs = 
+fun compile_pats match_name eqs = 
   let
-    val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs);
+    val ((n::names),(a::arities),mats) = unzip3 (map (match_eq match_name) eqs);
     val cname = if forall (fn x => n=x) names then n
           else fixrec_err "all equations in block must define the same function";
     val arity = if forall (fn x => a=x) arities then a
@@ -235,8 +251,14 @@
     
     fun unconcat [] _ = []
       | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));
+    val matcher_tab = FixrecMatchData.get thy;
+    fun match_name c =
+          case Symtab.lookup matcher_tab c of SOME m => m
+            | NONE => fixrec_err ("unknown pattern constructor: " ^ c);
+
     val pattern_blocks = unconcat lengths (map Logic.strip_imp_concl eqn_ts');
-    val compiled_ts = map (legacy_infer_term thy o compile_pats) pattern_blocks;
+    val compiled_ts =
+          map (legacy_infer_term thy o compile_pats match_name) pattern_blocks;
     val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs compiled_ts thy;
   in
     if strict then let (* only prove simp rules if strict = true *)
@@ -312,4 +334,6 @@
 
 end; (* local structure *)
 
+val setup = FixrecMatchData.init;
+
 end;