--- a/src/HOLCF/Tools/fixrec_package.ML Thu Feb 26 08:48:33 2009 -0800
+++ b/src/HOLCF/Tools/fixrec_package.ML Thu Feb 26 10:28:53 2009 -0800
@@ -12,6 +12,8 @@
val add_fixrec_i: bool -> ((binding * attribute list) * term) list list -> theory -> theory
val add_fixpat: Attrib.binding * string list -> theory -> theory
val add_fixpat_i: (binding * attribute list) * term list -> theory -> theory
+ val add_matchers: (string * string) list -> theory -> theory
+ val setup: theory -> theory
end;
structure FixrecPackage: FIXREC_PACKAGE =
@@ -123,6 +125,17 @@
(*********** monadic notation and pattern matching compilation ***********)
(*************************************************************************)
+structure FixrecMatchData = TheoryDataFun (
+ type T = string Symtab.table;
+ val empty = Symtab.empty;
+ val copy = I;
+ val extend = I;
+ fun merge _ tabs : T = Symtab.merge (K true) tabs;
+);
+
+(* associate match functions with pattern constants *)
+fun add_matchers ms = FixrecMatchData.map (fold Symtab.update ms);
+
fun add_names (Const(a,_), bs) = insert (op =) (Sign.base_name a) bs
| add_names (Free(a,_) , bs) = insert (op =) a bs
| add_names (f $ u , bs) = add_names (f, add_names(u, bs))
@@ -132,20 +145,20 @@
fun add_terms ts xs = foldr add_names xs ts;
(* builds a monadic term for matching a constructor pattern *)
-fun pre_build pat rhs vs taken =
+fun pre_build match_name pat rhs vs taken =
case pat of
Const(@{const_name Rep_CFun},_)$f$(v as Free(n,T)) =>
- pre_build f rhs (v::vs) taken
+ pre_build match_name f rhs (v::vs) taken
| Const(@{const_name Rep_CFun},_)$f$x =>
- let val (rhs', v, taken') = pre_build x rhs [] taken;
- in pre_build f rhs' (v::vs) taken' end
+ let val (rhs', v, taken') = pre_build match_name x rhs [] taken;
+ in pre_build match_name f rhs' (v::vs) taken' end
| Const(c,T) =>
let
val n = Name.variant taken "v";
fun result_type (Type(@{type_name "->"},[_,T])) (x::xs) = result_type T xs
| result_type T _ = T;
val v = Free(n, result_type T vs);
- val m = "match_"^(extern_name(Sign.base_name c));
+ val m = match_name c;
val k = lambda_ctuple vs rhs;
in
(%%:@{const_name Fixrec.bind}`(%%:m`v)`k, v, n::taken)
@@ -155,19 +168,22 @@
(* builds a monadic term for matching a function definition pattern *)
(* returns (name, arity, matcher) *)
-fun building pat rhs vs taken =
+fun building match_name pat rhs vs taken =
case pat of
Const(@{const_name Rep_CFun}, _)$f$(v as Free(n,T)) =>
- building f rhs (v::vs) taken
+ building match_name f rhs (v::vs) taken
| Const(@{const_name Rep_CFun}, _)$f$x =>
- let val (rhs', v, taken') = pre_build x rhs [] taken;
- in building f rhs' (v::vs) taken' end
+ let val (rhs', v, taken') = pre_build match_name x rhs [] taken;
+ in building match_name f rhs' (v::vs) taken' end
| Const(name,_) => (name, length vs, big_lambdas vs rhs)
| _ => fixrec_err "function is not declared as constant in theory";
-fun match_eq eq =
+fun match_eq match_name eq =
let val (lhs,rhs) = dest_eqs eq;
- in building lhs (%%:@{const_name Fixrec.return}`rhs) [] (add_terms [eq] []) end;
+ in
+ building match_name lhs (%%:@{const_name Fixrec.return}`rhs) []
+ (add_terms [eq] [])
+ end;
(* returns the sum (using +++) of the terms in ms *)
(* also applies "run" to the result! *)
@@ -190,9 +206,9 @@
in (x::xs, y::ys, z::zs) end;
(* this is the pattern-matching compiler function *)
-fun compile_pats eqs =
+fun compile_pats match_name eqs =
let
- val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs);
+ val ((n::names),(a::arities),mats) = unzip3 (map (match_eq match_name) eqs);
val cname = if forall (fn x => n=x) names then n
else fixrec_err "all equations in block must define the same function";
val arity = if forall (fn x => a=x) arities then a
@@ -235,8 +251,14 @@
fun unconcat [] _ = []
| unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));
+ val matcher_tab = FixrecMatchData.get thy;
+ fun match_name c =
+ case Symtab.lookup matcher_tab c of SOME m => m
+ | NONE => fixrec_err ("unknown pattern constructor: " ^ c);
+
val pattern_blocks = unconcat lengths (map Logic.strip_imp_concl eqn_ts');
- val compiled_ts = map (legacy_infer_term thy o compile_pats) pattern_blocks;
+ val compiled_ts =
+ map (legacy_infer_term thy o compile_pats match_name) pattern_blocks;
val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs compiled_ts thy;
in
if strict then let (* only prove simp rules if strict = true *)
@@ -312,4 +334,6 @@
end; (* local structure *)
+val setup = FixrecMatchData.init;
+
end;