src/HOLCF/fixrec_package.ML
changeset 16226 c17ac524d866
child 16387 67f6044c1891
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOLCF/fixrec_package.ML	Sat Jun 04 02:10:19 2005 +0200
@@ -0,0 +1,268 @@
+(*  Title:      HOLCF/fixrec_package.ML
+    ID:         $Id$
+    Author:     Amber Telfer and Brian Huffman
+
+Recursive function definition package for HOLCF.
+*)
+
+signature FIXREC_PACKAGE =
+sig
+  val add_fixrec: string list -> theory -> theory
+  val add_fixpat: string * string -> theory -> theory
+end;
+
+structure FixrecPackage: FIXREC_PACKAGE =
+struct
+
+local
+open ThyParse in
+
+(* ->> is taken from holcf_logic.ML *)
+(* TODO: fix dependencies so we can import HOLCFLogic here *)
+infixr 6 ->>;
+fun S ->> T = Type ("Cfun.->",[S,T]);
+
+(* extern_name is taken from domain/library.ML *)
+fun extern_name con = case Symbol.explode con of 
+		   ("o"::"p"::" "::rest) => implode rest
+		   | _ => con;
+
+(***************  This is the building functional **************)
+
+(* converts string proposition to a cterm *)
+fun all eqs thy = (let val sign = sign_of (thy) in 
+                  ((term_of o (Thm.read_cterm sign) o rpair propT) eqs) end);
+
+(* splits a cterm into the right and lefthand sides of equality *)
+fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
+  | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs))    = (lhs,rhs)
+  | dest_eqs t = sys_error (Sign.string_of_term (sign_of (the_context())) t);
+
+(* building fixpoint functional def for an equation with only 
+   variables as parameters *)
+fun func1 (lhs as Const(name,T), rhs) =
+  let
+    val basename = Sign.base_name name;
+    val funcT = T ->> T;
+    val functional = Const ("Cfun.Abs_CFun", (T --> T) --> funcT) $
+          Abs (basename, T, abstract_over (lhs,rhs));
+    val fix_type = funcT ->> T;
+    val fix_const = Const ("Fix.fix", fix_type);
+    val func_type = fix_type --> funcT --> T;
+    val rhs' = Const ("Cfun.Rep_CFun",func_type)$fix_const$functional;
+  in
+    (name, (basename^"_fixdef", equals T $ lhs $ rhs'))
+  end
+  | func1 t = sys_error "func1: not of correct form";
+
+(**************************************************************)
+(* these are helpful functions copied from HOLCF/domain/library.ML *)
+fun %: s = Free(s,dummyT);
+fun %%: s = Const(s,dummyT);
+infix 0 ==;  fun S ==  T = %%:"==" $ S $ T;
+infix 1 ===; fun S === T = %%:"op =" $ S $ T;
+infix 9 `  ; fun f`  x = %%:"Rep_CFun" $ f $ x;
+
+(* infers the type of a term *)
+fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
+
+(*****************************************************************)
+(* monadic notation and pattern matching *)
+(*****************************************************************)
+
+(* these 3 functions strip off parameters and destruct constructors *)
+fun strip_cpair (Const("Cfun.Rep_CFun",_) $
+      (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) =
+        b :: strip_cpair r
+  | strip_cpair c = [c];
+
+fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs);
+
+fun big_lambdas [] rhs = rhs
+  | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
+
+(* builds a big lamdba expression with a tuple *)
+fun lambda_tuple [] rhs = big_lambda (%:"unit") rhs
+  | lambda_tuple [v] rhs = big_lambda v rhs
+  | lambda_tuple (v::vs) rhs =
+      %%:"Cprod.csplit"`(big_lambda v (lambda_tuple vs rhs));
+
+fun add_names (Const(a,_), bs) = NameSpace.base a ins_string bs
+  | add_names (Free(a,_), bs) = a ins_string bs
+  | add_names (f$u, bs) = add_names (f, add_names(u, bs))
+  | add_names (Abs(a,_,t), bs) = add_names(t,a ins_string bs)
+  | add_names (_, bs) = bs;
+fun add_terms ts xs = foldr add_names xs ts;
+
+(* builds a monadic term for matching a constructor pattern *)
+fun pre_build (Const("Cfun.Rep_CFun",_)$f$(v as Free(n,T))) rhs vs taken =
+      pre_build f rhs (v::vs) taken
+  | pre_build (Const("Cfun.Rep_CFun",_)$f$x) rhs vs taken =
+      let val (rhs', v, taken') = pre_build x rhs [] taken;
+      in
+        pre_build f rhs' (v::vs) taken'
+      end
+  | pre_build (Const(c,T)) rhs vs taken =
+      let
+        val n = variant taken "v";
+        fun result_type (Type(_,[_,T])) (x::xs) = result_type T xs
+          | result_type T _ = T;
+        val v = Free(n, result_type T vs);
+        val m = "match_"^(extern_name(NameSpace.base c));
+        val k = lambda_tuple vs rhs;
+      in
+        (%%:"bind"`(%%:m`v)`k, v, n::taken)
+      end;
+
+(* builds a monadic term for matching a function definition pattern *)
+(* returns (name, arity, matcher) *)
+fun building (Const("Cfun.Rep_CFun", _)$f$(v as Free(n,T))) rhs vs taken =
+      building f rhs (v::vs) taken
+  | building (Const("Cfun.Rep_CFun", _)$f$x) rhs vs taken =
+      let
+        val (rhs', v, taken') = pre_build x rhs [] taken;
+      in
+        building f rhs' (v::vs) taken'
+      end
+  | building (c as Const(_,_)) rhs vs taken = (c, length vs, big_lambdas vs rhs)
+  | building _ _ _ _ = sys_error "function is not declared as constant in theory";
+
+fun match_eq thy f = 
+  let
+    val e = (all f thy);
+    val (lhs,rhs) = dest_eqs e;
+    val (Const(n,_), a, t) = building lhs (%%:"return"`rhs) [] (add_terms [e] []);
+  in
+    (n, a, t)
+  end;
+
+(* returns the sum (using +++) of the terms in ms *)
+(* also applies "run" to the result! *)
+fun fatbar arity ms =
+  let
+    fun unLAM 0 t = t
+      | unLAM n (_$Abs(_,_,t)) = unLAM (n-1) t
+      | unLAM _ _ = sys_error "FIXREC: internal error, not enough LAMs";
+    fun reLAM 0 t = t
+      | reLAM n t = reLAM (n-1) (%%:"Abs_CFun" $ Abs("",dummyT,t));
+    fun mplus (x,y) = %%:"Fixrec.mplus"`x`y;
+    val msum = foldr1 mplus (map (unLAM arity) ms);
+  in
+    reLAM arity (%%:"Fixrec.run"`msum)
+  end;
+
+(***************************************************************)
+(*** Proving associated theorems ***)
+
+fun prove_thm thy unfold_thm x =
+  let
+    val ss = simpset_of thy;
+    val thm = prove_goalw thy [] x (fn _ => [SOLVE(stac unfold_thm 1 THEN simp_tac ss 1)])
+      handle _ => sys_error (x^" :: proof failed on this equation.");
+  in thm end;
+
+(* this proves that each equation is a theorem *)
+fun prove_list thy unfold_thm [] = []
+  | prove_list thy unfold_thm (x::xs) =
+      prove_thm thy unfold_thm x :: prove_list thy unfold_thm xs;
+
+(* this proves the def without fix is a theorem, this uses the fixpoint def *)
+fun make_simp name eqs ct fixdef_thm thy' = 
+  let
+    val basename = NameSpace.base name;
+    val ss = simpset_of thy';
+    val eq_thm = fixdef_thm RS fix_eq2;
+    val unfold_thm = prove_goalw_cterm [] ct
+      (fn _ => [(rtac (eq_thm RS trans) 1) THEN (simp_tac ss 1)]);
+    val ind_thm = fixdef_thm RS def_fix_ind;
+    val rew_thms = prove_list thy' unfold_thm eqs;
+    val thmss =
+      [ (basename^"_unfold", [unfold_thm])
+      , (basename^"_ind", [ind_thm])
+      , (basename^"_rews", rew_thms) ]
+  in
+    (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy'
+  end;
+
+fun unzip3 [] = ([],[],[])
+  | unzip3 ((x,y,z)::ts) =
+      let val (xs,ys,zs) = unzip3 ts
+      in (x::xs, y::ys, z::zs) end;
+
+(* this is the main processing function *)
+fun pat_fun eqs thy = 
+  let
+    val ((n::names),(a::arities),mats) = unzip3 (map (match_eq thy) eqs);
+    val cname = if forall (fn x => n=x) names then n
+          else sys_error "PAT_FUN: all equations must define the same function";
+    val arity = if forall (fn x => a=x) arities then a
+          else sys_error "FIXREC: all equations must have the same arity";
+    val msum = fatbar arity mats;
+    val v = variant (add_term_names (msum,[])) "v";
+    val funcc = infer (HOLogic.mk_Trueprop (%%:cname === msum)) thy;
+    val (name', fixdef_name_term) = func1 (dest_eqs funcc);
+    val (thy', [fixdef_thm]) =
+      PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy;
+  in
+    make_simp name' eqs (cterm_of (sign_of thy') funcc) fixdef_thm thy'
+  end;
+
+(***************************************************************)
+
+(* this calls the main processing function and then returns the new state *)
+fun add_fixrec eqs = pat_fun eqs;
+
+(*****************************************************************)
+(*** Fixpat ***)
+
+(* like Term.strip_comb, but with continuous application *)
+fun strip_cdot u : term * term list =
+  let fun stripc (Const("Cfun.Rep_CFun",_)$f$t, ts) = stripc (f, t::ts)
+        | stripc x = x
+  in stripc(u,[]) end;
+
+fun fix_pat name pat thy = 
+  let
+    val sign = sign_of thy;
+    val ct = Thm.read_cterm sign (pat, dummyT);
+    val (Const (f,_), args) = strip_cdot (term_of ct);
+    val unfold_thm = Goals.get_thm thy (f^"_unfold");
+    fun add_arg (arg,thm) = instantiate' [] [SOME arg] (thm RS cfun_fun_cong);
+    val unfold_thm' = foldl add_arg (freezeT unfold_thm) (map (cterm_of sign) args);
+    val thm = simplify (simpset_of thy) unfold_thm';
+  in
+    (#1 o PureThy.add_thmss [Thm.no_attributes (name, [thm])]) thy
+  end;
+
+fun add_fixpat (name,pat) = fix_pat name pat;
+
+(*****************************************************************)
+(*** Parsers ***)
+
+local structure P = OuterParse and K = OuterSyntax.Keyword in
+
+val fixrec_decl = Scan.repeat1 P.prop;
+
+(* this builds a parser for a new keyword, fixrec, whose functionality 
+is defined by add_fixrec *)
+val fixrecP =
+  OuterSyntax.command "fixrec" "parser for fixrec functions" K.thy_decl
+    (fixrec_decl >> (Toplevel.theory o add_fixrec));
+
+(* this adds the parser for fixrec to the syntax *)
+val _ = OuterSyntax.add_parsers [fixrecP];
+
+(* fixpat parser *)
+val fixpat_decl = P.name -- P.prop;
+
+val fixpatP =
+  OuterSyntax.command "fixpat" "testing out this parser" K.thy_decl
+    (fixpat_decl >> (Toplevel.theory o add_fixpat));
+
+val _ = OuterSyntax.add_parsers [fixpatP];
+
+end; (* local structure *)
+
+end; (* local open *)
+
+end; (* struct *)