fixrec package now handles mutually-recursive definitions
authorhuffman
Wed, 15 Jun 2005 20:50:38 +0200
changeset 16401 57c35ede00b9
parent 16400 f2ab5797bbd0
child 16402 36f41d5e3b3e
fixrec package now handles mutually-recursive definitions
src/HOLCF/Fixrec.thy
src/HOLCF/fixrec_package.ML
--- a/src/HOLCF/Fixrec.thy	Wed Jun 15 14:59:25 2005 +0200
+++ b/src/HOLCF/Fixrec.thy	Wed Jun 15 20:50:38 2005 +0200
@@ -134,6 +134,27 @@
   "match_up\<cdot>\<bottom> = \<bottom>"
 by (simp_all add: match_up_def)
 
+subsection {* Mutual recursion *}
+
+text {*
+  The following rules are used to prove unfolding theorems from
+  fixed-point definitions of mutually recursive functions.
+*}
+
+lemma cpair_equalI: "\<lbrakk>x \<equiv> cfst\<cdot>p; y \<equiv> csnd\<cdot>p\<rbrakk> \<Longrightarrow> <x,y> \<equiv> p"
+by (simp add: surjective_pairing_Cprod2)
+
+lemma cpair_eqD1: "<x,y> = <x',y'> \<Longrightarrow> x = x'"
+by simp
+
+lemma cpair_eqD2: "<x,y> = <x',y'> \<Longrightarrow> y = y'"
+by simp
+
+ML {*
+val cpair_equalI = thm "cpair_equalI";
+val cpair_eqD1 = thm "cpair_eqD1";
+val cpair_eqD2 = thm "cpair_eqD2";
+*}
 
 subsection {* Intitializing the fixrec package *}
 
--- a/src/HOLCF/fixrec_package.ML	Wed Jun 15 14:59:25 2005 +0200
+++ b/src/HOLCF/fixrec_package.ML	Wed Jun 15 20:50:38 2005 +0200
@@ -7,7 +7,7 @@
 
 signature FIXREC_PACKAGE =
 sig
-  val add_fixrec: string list -> theory -> theory
+  val add_fixrec: string list list -> theory -> theory
   val add_fixpat: string * string -> theory -> theory
 end;
 
@@ -27,6 +27,8 @@
 		   ("o"::"p"::" "::rest) => implode rest
 		   | _ => con;
 
+val mk_trp = HOLogic.mk_Trueprop;
+
 (* splits a cterm into the right and lefthand sides of equality *)
 fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
   | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs))    = (lhs,rhs)
@@ -37,57 +39,92 @@
 fun %%: s = Const(s,dummyT);
 infix 0 ==;  fun S ==  T = %%:"==" $ S $ T;
 infix 1 ===; fun S === T = %%:"op =" $ S $ T;
-infix 9 `  ; fun f ` x = %%:"Rep_CFun" $ f $ x;
+infix 9 `  ; fun f ` x = %%:"Cfun.Rep_CFun" $ f $ x;
 
 (* infers the type of a term *)
-fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
+(* similar to Theory.inferT_axm, but allows any type *)
+fun infer sg t =
+  fst (Sign.infer_types (Sign.pp sg) sg (K NONE) (K NONE) [] true ([t],dummyT));
+
+(* The next few functions build continuous lambda abstractions *)
+
+(* Similar to Term.lambda, but allows abstraction over constants *)
+fun lambda' (v as Free (x, T)) t = Abs (x, T, abstract_over (v, t))
+  | lambda' (v as Var ((x, _), T)) t = Abs (x, T, abstract_over (v, t))
+  | lambda' (v as Const (x, T)) t = Abs (Sign.base_name x, T, abstract_over (v, t))
+  | lambda' v t = raise TERM ("lambda'", [v, t]);
+
+(* builds the expression (LAM v. rhs) *)
+fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda' v rhs);
+
+(* builds the expression (LAM v1 v2 .. vn. rhs) *)
+fun big_lambdas [] rhs = rhs
+  | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
+
+(* builds the expression (LAM <v1,v2,..,vn>. rhs) *)
+fun lambda_ctuple [] rhs = big_lambda (%:"unit") rhs
+  | lambda_ctuple (v::[]) rhs = big_lambda v rhs
+  | lambda_ctuple (v::vs) rhs =
+      %%:"Cprod.csplit"`(big_lambda v (lambda_ctuple vs rhs));
+
+(* builds the expression <v1,v2,..,vn> *)
+fun mk_ctuple [] = %%:"UU"
+|   mk_ctuple (t::[]) = t
+|   mk_ctuple (t::ts) = %%:"Cprod.cpair"`t`(mk_ctuple ts);
 
 (*************************************************************************)
-(************ fixed-point definitions and unfolding theorems *************)
+(************************ fixed-point definitions ************************)
 (*************************************************************************)
 
-fun func1 (lhs as Const(name,T), rhs) =
+fun add_fixdefs eqs thy =
   let
-    val basename = Sign.base_name name;
-    val funcT = T ->> T;
-    val functional = Const ("Cfun.Abs_CFun", (T --> T) --> funcT) $
-          Abs (basename, T, abstract_over (lhs,rhs));
-    val fix_type = funcT ->> T;
-    val fix_const = Const ("Fix.fix", fix_type);
-    val func_type = fix_type --> funcT --> T;
-    val rhs' = Const ("Cfun.Rep_CFun",func_type)$fix_const$functional;
+    val (lhss,rhss) = ListPair.unzip (map dest_eqs eqs);
+    val fixpoint = %%:"Fix.fix"`lambda_ctuple lhss (mk_ctuple rhss);
+    
+    fun one_def (l as Const(n,T)) r =
+          let val b = Sign.base_name n in (b, (b^"_fixdef", l == r)) end
+      | one_def _ _ = sys_error "fixdefs: lhs not of correct form";
+    fun defs [] _ = []
+      | defs (l::[]) r = [one_def l r]
+      | defs (l::ls) r = one_def l (%%:"Cprod.cfst"`r) :: defs ls (%%:"Cprod.csnd"`r);
+    val (names, pre_fixdefs) = ListPair.unzip (defs lhss fixpoint);
+    
+    val fixdefs = map (inferT_axm (sign_of thy)) pre_fixdefs;
+    val (thy', fixdef_thms) =
+      PureThy.add_defs_i false (map Thm.no_attributes fixdefs) thy;
+    val ctuple_fixdef_thm = foldr1 (fn (x,y) => cpair_equalI OF [x,y]) fixdef_thms;
+    
+    fun mk_cterm t = let val sg' = sign_of thy' in cterm_of sg' (infer sg' t) end;
+    val ctuple_unfold_ct = mk_cterm (mk_trp (mk_ctuple lhss === mk_ctuple rhss));
+    val ctuple_unfold_thm = prove_goalw_cterm [] ctuple_unfold_ct
+          (fn _ => [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1,
+                    simp_tac (simpset_of thy') 1]);
+    val ctuple_induct_thm =
+          (space_implode "_" names ^ "_induct", [ctuple_fixdef_thm RS def_fix_ind]);
+    
+    fun unfolds [] thm = []
+      | unfolds (n::[]) thm = [(n^"_unfold", [thm])]
+      | unfolds (n::ns) thm = let
+          val thmL = thm RS cpair_eqD1;
+          val thmR = thm RS cpair_eqD2;
+        in (n^"_unfold", [thmL]) :: unfolds ns thmR end;
+    val unfold_thmss = unfolds names ctuple_unfold_thm;
+    val thmss = ctuple_induct_thm :: unfold_thmss;
+    val (thy'', _) = PureThy.add_thmss (map Thm.no_attributes thmss) thy';
   in
-    (name, (basename^"_fixdef", equals T $ lhs $ rhs'))
-  end
-  | func1 t = sys_error "func1: not of correct form";
+    (thy'', names, fixdef_thms, List.concat (map snd unfold_thmss))
+  end;
 
 (*************************************************************************)
 (*********** monadic notation and pattern matching compilation ***********)
 (*************************************************************************)
 
-(* these 3 functions strip off parameters and destruct constructors *)
-(*
-fun strip_cpair (Const("Cfun.Rep_CFun",_) $
-      (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) =
-        b :: strip_cpair r
-  | strip_cpair c = [c];
-*)
-fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs);
-
-fun big_lambdas [] rhs = rhs
-  | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
+fun add_names (Const(a,_), bs) = Sign.base_name a ins_string bs
+  | add_names (Free(a,_) , bs) = a ins_string bs
+  | add_names (f $ u     , bs) = add_names (f, add_names(u, bs))
+  | add_names (Abs(a,_,t), bs) = add_names (t, a ins_string bs)
+  | add_names (_         , bs) = bs;
 
-(* builds a big lamdba expression with a tuple *)
-fun lambda_tuple [] rhs = big_lambda (%:"unit") rhs
-  | lambda_tuple [v] rhs = big_lambda v rhs
-  | lambda_tuple (v::vs) rhs =
-      %%:"Cprod.csplit"`(big_lambda v (lambda_tuple vs rhs));
-
-fun add_names (Const(a,_), bs) = NameSpace.base a ins_string bs
-  | add_names (Free(a,_), bs) = a ins_string bs
-  | add_names (f$u, bs) = add_names (f, add_names(u, bs))
-  | add_names (Abs(a,_,t), bs) = add_names(t,a ins_string bs)
-  | add_names (_, bs) = bs;
 fun add_terms ts xs = foldr add_names xs ts;
 
 (* builds a monadic term for matching a constructor pattern *)
@@ -104,8 +141,8 @@
         fun result_type (Type("Cfun.->",[_,T])) (x::xs) = result_type T xs
           | result_type T _ = T;
         val v = Free(n, result_type T vs);
-        val m = "match_"^(extern_name(NameSpace.base c));
-        val k = lambda_tuple vs rhs;
+        val m = "match_"^(extern_name(Sign.base_name c));
+        val k = lambda_ctuple vs rhs;
       in
         (%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken)
       end;
@@ -159,14 +196,14 @@
           else sys_error "FIXREC: all equations must have the same arity";
     val rhs = fatbar arity mats;
   in
-    HOLogic.mk_Trueprop (%%:cname === rhs)
+    mk_trp (%%:cname === rhs)
   end;
 
 (*************************************************************************)
 (********************** Proving associated theorems **********************)
 (*************************************************************************)
 
-fun prove_thm thy unfold_thm ct =
+fun prove_rew thy unfold_thm ct =
   let
     val ss = simpset_of thy;
     val thm = prove_goalw_cterm [] ct
@@ -175,18 +212,24 @@
   in thm end;
 
 (* this proves that each equation is a theorem *)
-fun prove_list thy unfold_thm [] = []
-  | prove_list thy unfold_thm (x::xs) =
-      prove_thm thy unfold_thm x :: prove_list thy unfold_thm xs;
+fun prove_rews thy (unfold_thm,cts) = map (prove_rew thy unfold_thm) cts;
+
+(* proves the pattern matching equations as theorems, using unfold *)
+fun make_simps names unfold_thms ctss thy = 
+  let
+    val thm_names = map (fn name => name^"_rews") names;
+    val rew_thmss = ListPair.map (prove_rews thy) (unfold_thms, ctss);
+    val thmss = ListPair.zip (thm_names, rew_thmss);
+  in
+    (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy
+  end;
 
 (* this proves the def without fix is a theorem, this uses the fixpoint def *)
+(*
 fun make_simp name eqs ct fixdef_thm thy' = 
   let
-    val basename = NameSpace.base name;
     val ss = simpset_of thy';
     val eq_thm = fixdef_thm RS fix_eq2;
-    val unfold_thm = prove_goalw_cterm [] ct
-      (fn _ => [(rtac (eq_thm RS trans) 1) THEN (simp_tac ss 1)]);
     val ind_thm = fixdef_thm RS def_fix_ind;
     val rew_thms = prove_list thy' unfold_thm eqs;
     val thmss =
@@ -196,25 +239,21 @@
   in
     (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy'
   end;
-
+*)
 (*************************************************************************)
 (************************* Main fixrec function **************************)
 (*************************************************************************)
 
 (* this calls the main processing function and then returns the new state *)
-fun add_fixrec strs thy =
+fun add_fixrec strss thy =
   let
     val sg = sign_of thy;
-    val cts = map (Thm.read_cterm sg o rpair propT) strs;
-    val eqs = map term_of cts;
-    val funcc = infer (compile_pats eqs) thy;
-    val _ = print_cterm (cterm_of sg funcc);
-    val (name', fixdef_name_term) = func1 (dest_eqs funcc);
-    val (thy', [fixdef_thm]) =
-      PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy;
-    val ct = cterm_of (sign_of thy') funcc;
+    val ctss = map (map (Thm.read_cterm sg o rpair propT)) strss;
+    val tss = map (map term_of) ctss;
+    val ts' = map (fn ts => infer sg (compile_pats ts)) tss;
+    val (thy', names, fixdef_thms, unfold_thms) = add_fixdefs ts' thy;
   in
-    make_simp name' cts ct fixdef_thm thy'
+    make_simps names unfold_thms ctss thy'
   end;
 
 (*************************************************************************)
@@ -226,7 +265,7 @@
     val sign = sign_of thy;
     val t = term_of (Thm.read_cterm sign (pat, dummyT));
     val T = fastype_of t;
-    val eq = HOLogic.mk_Trueprop (HOLogic.eq_const T $ t $ Var (("x",0),T));
+    val eq = mk_trp (HOLogic.eq_const T $ t $ Var (("x",0),T));
     fun head_const (Const ("Cfun.Rep_CFun",_) $ f $ t) = head_const f
       | head_const (Const (c,_)) = c
       | head_const _ = sys_error "FIXPAT: function is not declared as constant in theory";
@@ -247,7 +286,7 @@
 
 local structure P = OuterParse and K = OuterSyntax.Keyword in
 
-val fixrec_decl = (*P.and_list1*) (Scan.repeat1 P.prop);
+val fixrec_decl = P.and_list1 (Scan.repeat1 P.prop);
 
 (* this builds a parser for a new keyword, fixrec, whose functionality 
 is defined by add_fixrec *)