src/HOLCF/Tools/fixrec_package.ML
changeset 30131 6be1be402ef0
parent 29585 c23295521af5
child 30132 243a05a67c41
--- a/src/HOLCF/Tools/fixrec_package.ML	Thu Feb 26 08:48:33 2009 -0800
+++ b/src/HOLCF/Tools/fixrec_package.ML	Thu Feb 26 10:28:53 2009 -0800
@@ -12,6 +12,8 @@
   val add_fixrec_i: bool -> ((binding * attribute list) * term) list list -> theory -> theory
   val add_fixpat: Attrib.binding * string list -> theory -> theory
   val add_fixpat_i: (binding * attribute list) * term list -> theory -> theory
+  val add_matchers: (string * string) list -> theory -> theory
+  val setup: theory -> theory
 end;
 
 structure FixrecPackage: FIXREC_PACKAGE =
@@ -123,6 +125,17 @@
 (*********** monadic notation and pattern matching compilation ***********)
 (*************************************************************************)
 
+structure FixrecMatchData = TheoryDataFun (
+  type T = string Symtab.table;
+  val empty = Symtab.empty;
+  val copy = I;
+  val extend = I;
+  fun merge _ tabs : T = Symtab.merge (K true) tabs;
+);
+
+(* associate match functions with pattern constants *)
+fun add_matchers ms = FixrecMatchData.map (fold Symtab.update ms);
+
 fun add_names (Const(a,_), bs) = insert (op =) (Sign.base_name a) bs
   | add_names (Free(a,_) , bs) = insert (op =) a bs
   | add_names (f $ u     , bs) = add_names (f, add_names(u, bs))
@@ -132,20 +145,20 @@
 fun add_terms ts xs = foldr add_names xs ts;
 
 (* builds a monadic term for matching a constructor pattern *)
-fun pre_build pat rhs vs taken =
+fun pre_build match_name pat rhs vs taken =
   case pat of
     Const(@{const_name Rep_CFun},_)$f$(v as Free(n,T)) =>
-      pre_build f rhs (v::vs) taken
+      pre_build match_name f rhs (v::vs) taken
   | Const(@{const_name Rep_CFun},_)$f$x =>
-      let val (rhs', v, taken') = pre_build x rhs [] taken;
-      in pre_build f rhs' (v::vs) taken' end
+      let val (rhs', v, taken') = pre_build match_name x rhs [] taken;
+      in pre_build match_name f rhs' (v::vs) taken' end
   | Const(c,T) =>
       let
         val n = Name.variant taken "v";
         fun result_type (Type(@{type_name "->"},[_,T])) (x::xs) = result_type T xs
           | result_type T _ = T;
         val v = Free(n, result_type T vs);
-        val m = "match_"^(extern_name(Sign.base_name c));
+        val m = match_name c;
         val k = lambda_ctuple vs rhs;
       in
         (%%:@{const_name Fixrec.bind}`(%%:m`v)`k, v, n::taken)
@@ -155,19 +168,22 @@
 
 (* builds a monadic term for matching a function definition pattern *)
 (* returns (name, arity, matcher) *)
-fun building pat rhs vs taken =
+fun building match_name pat rhs vs taken =
   case pat of
     Const(@{const_name Rep_CFun}, _)$f$(v as Free(n,T)) =>
-      building f rhs (v::vs) taken
+      building match_name f rhs (v::vs) taken
   | Const(@{const_name Rep_CFun}, _)$f$x =>
-      let val (rhs', v, taken') = pre_build x rhs [] taken;
-      in building f rhs' (v::vs) taken' end
+      let val (rhs', v, taken') = pre_build match_name x rhs [] taken;
+      in building match_name f rhs' (v::vs) taken' end
   | Const(name,_) => (name, length vs, big_lambdas vs rhs)
   | _ => fixrec_err "function is not declared as constant in theory";
 
-fun match_eq eq = 
+fun match_eq match_name eq = 
   let val (lhs,rhs) = dest_eqs eq;
-  in building lhs (%%:@{const_name Fixrec.return}`rhs) [] (add_terms [eq] []) end;
+  in
+    building match_name lhs (%%:@{const_name Fixrec.return}`rhs) []
+      (add_terms [eq] [])
+  end;
 
 (* returns the sum (using +++) of the terms in ms *)
 (* also applies "run" to the result! *)
@@ -190,9 +206,9 @@
       in (x::xs, y::ys, z::zs) end;
 
 (* this is the pattern-matching compiler function *)
-fun compile_pats eqs = 
+fun compile_pats match_name eqs = 
   let
-    val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs);
+    val ((n::names),(a::arities),mats) = unzip3 (map (match_eq match_name) eqs);
     val cname = if forall (fn x => n=x) names then n
           else fixrec_err "all equations in block must define the same function";
     val arity = if forall (fn x => a=x) arities then a
@@ -235,8 +251,14 @@
     
     fun unconcat [] _ = []
       | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));
+    val matcher_tab = FixrecMatchData.get thy;
+    fun match_name c =
+          case Symtab.lookup matcher_tab c of SOME m => m
+            | NONE => fixrec_err ("unknown pattern constructor: " ^ c);
+
     val pattern_blocks = unconcat lengths (map Logic.strip_imp_concl eqn_ts');
-    val compiled_ts = map (legacy_infer_term thy o compile_pats) pattern_blocks;
+    val compiled_ts =
+          map (legacy_infer_term thy o compile_pats match_name) pattern_blocks;
     val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs compiled_ts thy;
   in
     if strict then let (* only prove simp rules if strict = true *)
@@ -312,4 +334,6 @@
 
 end; (* local structure *)
 
+val setup = FixrecMatchData.init;
+
 end;