cleaned up and reorganized
authorhuffman
Tue, 14 Jun 2005 04:04:09 +0200
changeset 16387 67f6044c1891
parent 16386 c6f5ade29608
child 16388 1ff571813848
cleaned up and reorganized
src/HOLCF/fixrec_package.ML
--- a/src/HOLCF/fixrec_package.ML	Tue Jun 14 03:50:20 2005 +0200
+++ b/src/HOLCF/fixrec_package.ML	Tue Jun 14 04:04:09 2005 +0200
@@ -27,19 +27,25 @@
 		   ("o"::"p"::" "::rest) => implode rest
 		   | _ => con;
 
-(***************  This is the building functional **************)
-
-(* converts string proposition to a cterm *)
-fun all eqs thy = (let val sign = sign_of (thy) in 
-                  ((term_of o (Thm.read_cterm sign) o rpair propT) eqs) end);
-
 (* splits a cterm into the right and lefthand sides of equality *)
 fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
   | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs))    = (lhs,rhs)
   | dest_eqs t = sys_error (Sign.string_of_term (sign_of (the_context())) t);
 
-(* building fixpoint functional def for an equation with only 
-   variables as parameters *)
+(* these are helpful functions copied from HOLCF/domain/library.ML *)
+fun %: s = Free(s,dummyT);
+fun %%: s = Const(s,dummyT);
+infix 0 ==;  fun S ==  T = %%:"==" $ S $ T;
+infix 1 ===; fun S === T = %%:"op =" $ S $ T;
+infix 9 `  ; fun f ` x = %%:"Rep_CFun" $ f $ x;
+
+(* infers the type of a term *)
+fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
+
+(*************************************************************************)
+(************ fixed-point definitions and unfolding theorems *************)
+(*************************************************************************)
+
 fun func1 (lhs as Const(name,T), rhs) =
   let
     val basename = Sign.base_name name;
@@ -55,27 +61,17 @@
   end
   | func1 t = sys_error "func1: not of correct form";
 
-(**************************************************************)
-(* these are helpful functions copied from HOLCF/domain/library.ML *)
-fun %: s = Free(s,dummyT);
-fun %%: s = Const(s,dummyT);
-infix 0 ==;  fun S ==  T = %%:"==" $ S $ T;
-infix 1 ===; fun S === T = %%:"op =" $ S $ T;
-infix 9 `  ; fun f`  x = %%:"Rep_CFun" $ f $ x;
-
-(* infers the type of a term *)
-fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
-
-(*****************************************************************)
-(* monadic notation and pattern matching *)
-(*****************************************************************)
+(*************************************************************************)
+(*********** monadic notation and pattern matching compilation ***********)
+(*************************************************************************)
 
 (* these 3 functions strip off parameters and destruct constructors *)
+(*
 fun strip_cpair (Const("Cfun.Rep_CFun",_) $
       (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) =
         b :: strip_cpair r
   | strip_cpair c = [c];
-
+*)
 fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs);
 
 fun big_lambdas [] rhs = rhs
@@ -95,46 +91,43 @@
 fun add_terms ts xs = foldr add_names xs ts;
 
 (* builds a monadic term for matching a constructor pattern *)
-fun pre_build (Const("Cfun.Rep_CFun",_)$f$(v as Free(n,T))) rhs vs taken =
+fun pre_build pat rhs vs taken =
+  case pat of
+    Const("Cfun.Rep_CFun",_)$f$(v as Free(n,T)) =>
       pre_build f rhs (v::vs) taken
-  | pre_build (Const("Cfun.Rep_CFun",_)$f$x) rhs vs taken =
+  | Const("Cfun.Rep_CFun",_)$f$x =>
       let val (rhs', v, taken') = pre_build x rhs [] taken;
-      in
-        pre_build f rhs' (v::vs) taken'
-      end
-  | pre_build (Const(c,T)) rhs vs taken =
+      in pre_build f rhs' (v::vs) taken' end
+  | Const(c,T) =>
       let
         val n = variant taken "v";
-        fun result_type (Type(_,[_,T])) (x::xs) = result_type T xs
+        fun result_type (Type("Cfun.->",[_,T])) (x::xs) = result_type T xs
           | result_type T _ = T;
         val v = Free(n, result_type T vs);
         val m = "match_"^(extern_name(NameSpace.base c));
         val k = lambda_tuple vs rhs;
       in
-        (%%:"bind"`(%%:m`v)`k, v, n::taken)
+        (%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken)
       end;
 
 (* builds a monadic term for matching a function definition pattern *)
 (* returns (name, arity, matcher) *)
-fun building (Const("Cfun.Rep_CFun", _)$f$(v as Free(n,T))) rhs vs taken =
+fun building pat rhs vs taken =
+  case pat of
+    Const("Cfun.Rep_CFun", _)$f$(v as Free(n,T)) =>
       building f rhs (v::vs) taken
-  | building (Const("Cfun.Rep_CFun", _)$f$x) rhs vs taken =
-      let
-        val (rhs', v, taken') = pre_build x rhs [] taken;
-      in
-        building f rhs' (v::vs) taken'
-      end
-  | building (c as Const(_,_)) rhs vs taken = (c, length vs, big_lambdas vs rhs)
-  | building _ _ _ _ = sys_error "function is not declared as constant in theory";
+  | Const("Cfun.Rep_CFun", _)$f$x =>
+      let val (rhs', v, taken') = pre_build x rhs [] taken;
+      in building f rhs' (v::vs) taken' end
+  | Const(_,_) => (pat, length vs, big_lambdas vs rhs)
+  | _ => sys_error "function is not declared as constant in theory";
 
-fun match_eq thy f = 
+fun match_eq eq = 
   let
-    val e = (all f thy);
-    val (lhs,rhs) = dest_eqs e;
-    val (Const(n,_), a, t) = building lhs (%%:"return"`rhs) [] (add_terms [e] []);
-  in
-    (n, a, t)
-  end;
+    val (lhs,rhs) = dest_eqs eq;
+    val (Const(name,_), arity, term) =
+      building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []);
+  in (name, arity, term) end;
 
 (* returns the sum (using +++) of the terms in ms *)
 (* also applies "run" to the result! *)
@@ -151,14 +144,34 @@
     reLAM arity (%%:"Fixrec.run"`msum)
   end;
 
-(***************************************************************)
-(*** Proving associated theorems ***)
+fun unzip3 [] = ([],[],[])
+  | unzip3 ((x,y,z)::ts) =
+      let val (xs,ys,zs) = unzip3 ts
+      in (x::xs, y::ys, z::zs) end;
 
-fun prove_thm thy unfold_thm x =
+(* this is the pattern-matching compiler function *)
+fun compile_pats eqs = 
+  let
+    val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs);
+    val cname = if forall (fn x => n=x) names then n
+          else sys_error "FIXREC: all equations must define the same function";
+    val arity = if forall (fn x => a=x) arities then a
+          else sys_error "FIXREC: all equations must have the same arity";
+    val rhs = fatbar arity mats;
+  in
+    HOLogic.mk_Trueprop (%%:cname === rhs)
+  end;
+
+(*************************************************************************)
+(********************** Proving associated theorems **********************)
+(*************************************************************************)
+
+fun prove_thm thy unfold_thm ct =
   let
     val ss = simpset_of thy;
-    val thm = prove_goalw thy [] x (fn _ => [SOLVE(stac unfold_thm 1 THEN simp_tac ss 1)])
-      handle _ => sys_error (x^" :: proof failed on this equation.");
+    val thm = prove_goalw_cterm [] ct
+      (fn _ => [SOLVE(stac unfold_thm 1 THEN simp_tac ss 1)])
+        handle _ => sys_error (string_of_cterm ct^" :: proof failed on this equation.");
   in thm end;
 
 (* this proves that each equation is a theorem *)
@@ -184,64 +197,57 @@
     (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy'
   end;
 
-fun unzip3 [] = ([],[],[])
-  | unzip3 ((x,y,z)::ts) =
-      let val (xs,ys,zs) = unzip3 ts
-      in (x::xs, y::ys, z::zs) end;
+(*************************************************************************)
+(************************* Main fixrec function **************************)
+(*************************************************************************)
 
-(* this is the main processing function *)
-fun pat_fun eqs thy = 
+(* this calls the main processing function and then returns the new state *)
+fun add_fixrec strs thy =
   let
-    val ((n::names),(a::arities),mats) = unzip3 (map (match_eq thy) eqs);
-    val cname = if forall (fn x => n=x) names then n
-          else sys_error "PAT_FUN: all equations must define the same function";
-    val arity = if forall (fn x => a=x) arities then a
-          else sys_error "FIXREC: all equations must have the same arity";
-    val msum = fatbar arity mats;
-    val v = variant (add_term_names (msum,[])) "v";
-    val funcc = infer (HOLogic.mk_Trueprop (%%:cname === msum)) thy;
+    val sg = sign_of thy;
+    val cts = map (Thm.read_cterm sg o rpair propT) strs;
+    val eqs = map term_of cts;
+    val funcc = infer (compile_pats eqs) thy;
+    val _ = print_cterm (cterm_of sg funcc);
     val (name', fixdef_name_term) = func1 (dest_eqs funcc);
     val (thy', [fixdef_thm]) =
       PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy;
+    val ct = cterm_of (sign_of thy') funcc;
   in
-    make_simp name' eqs (cterm_of (sign_of thy') funcc) fixdef_thm thy'
+    make_simp name' cts ct fixdef_thm thy'
   end;
 
-(***************************************************************)
-
-(* this calls the main processing function and then returns the new state *)
-fun add_fixrec eqs = pat_fun eqs;
-
-(*****************************************************************)
-(*** Fixpat ***)
-
-(* like Term.strip_comb, but with continuous application *)
-fun strip_cdot u : term * term list =
-  let fun stripc (Const("Cfun.Rep_CFun",_)$f$t, ts) = stripc (f, t::ts)
-        | stripc x = x
-  in stripc(u,[]) end;
+(*************************************************************************)
+(******************************** Fixpat *********************************)
+(*************************************************************************)
 
 fun fix_pat name pat thy = 
   let
     val sign = sign_of thy;
-    val ct = Thm.read_cterm sign (pat, dummyT);
-    val (Const (f,_), args) = strip_cdot (term_of ct);
-    val unfold_thm = Goals.get_thm thy (f^"_unfold");
-    fun add_arg (arg,thm) = instantiate' [] [SOME arg] (thm RS cfun_fun_cong);
-    val unfold_thm' = foldl add_arg (freezeT unfold_thm) (map (cterm_of sign) args);
-    val thm = simplify (simpset_of thy) unfold_thm';
+    val t = term_of (Thm.read_cterm sign (pat, dummyT));
+    val T = fastype_of t;
+    val eq = HOLogic.mk_Trueprop (HOLogic.eq_const T $ t $ Var (("x",0),T));
+    fun head_const (Const ("Cfun.Rep_CFun",_) $ f $ t) = head_const f
+      | head_const (Const (c,_)) = c
+      | head_const _ = sys_error "FIXPAT: function is not declared as constant in theory";
+    val c = head_const t;
+    val unfold_thm = Goals.get_thm thy (c^"_unfold");
+    val thm = prove_goalw_cterm [] (cterm_of sign eq)
+          (fn _ => [stac unfold_thm 1, simp_tac (simpset_of thy) 1]);
+    val _ = print_thm thm;
   in
     (#1 o PureThy.add_thmss [Thm.no_attributes (name, [thm])]) thy
   end;
 
 fun add_fixpat (name,pat) = fix_pat name pat;
 
-(*****************************************************************)
-(*** Parsers ***)
+(*************************************************************************)
+(******************************** Parsers ********************************)
+(*************************************************************************)
 
 local structure P = OuterParse and K = OuterSyntax.Keyword in
 
-val fixrec_decl = Scan.repeat1 P.prop;
+val fixrec_decl = (*P.and_list1*) (Scan.repeat1 P.prop);
 
 (* this builds a parser for a new keyword, fixrec, whose functionality 
 is defined by add_fixrec *)