simplify fixrec proofs for mutually-recursive definitions; generate better fixpoint induction rules
authorhuffman
Mon, 11 May 2009 08:28:09 -0700
changeset 31095 b79d140f6d0b
parent 31094 7d6edb28bdbc
child 31096 e546e15089ef
simplify fixrec proofs for mutually-recursive definitions; generate better fixpoint induction rules
src/HOLCF/Fixrec.thy
src/HOLCF/Tools/fixrec_package.ML
--- a/src/HOLCF/Fixrec.thy	Mon May 11 08:24:35 2009 -0700
+++ b/src/HOLCF/Fixrec.thy	Mon May 11 08:28:09 2009 -0700
@@ -574,15 +574,23 @@
   fixed-point definitions of mutually recursive functions.
 *}
 
-lemma cpair_equalI: "\<lbrakk>x \<equiv> cfst\<cdot>p; y \<equiv> csnd\<cdot>p\<rbrakk> \<Longrightarrow> <x,y> \<equiv> p"
-by (simp add: surjective_pairing_Cprod2)
+lemma Pair_equalI: "\<lbrakk>x \<equiv> fst p; y \<equiv> snd p\<rbrakk> \<Longrightarrow> (x, y) \<equiv> p"
+by simp
 
-lemma cpair_eqD1: "<x,y> = <x',y'> \<Longrightarrow> x = x'"
+lemma Pair_eqD1: "(x, y) = (x', y') \<Longrightarrow> x = x'"
 by simp
 
-lemma cpair_eqD2: "<x,y> = <x',y'> \<Longrightarrow> y = y'"
+lemma Pair_eqD2: "(x, y) = (x', y') \<Longrightarrow> y = y'"
 by simp
 
+lemma def_cont_fix_eq:
+  "\<lbrakk>f \<equiv> fix\<cdot>(Abs_CFun F); cont F\<rbrakk> \<Longrightarrow> f = F f"
+by (simp, subst fix_eq, simp)
+
+lemma def_cont_fix_ind:
+  "\<lbrakk>f \<equiv> fix\<cdot>(Abs_CFun F); cont F; adm P; P \<bottom>; \<And>x. P x \<Longrightarrow> P (F x)\<rbrakk> \<Longrightarrow> P f"
+by (simp add: fix_ind)
+
 text {* lemma for proving rewrite rules *}
 
 lemma ssubst_lhs: "\<lbrakk>t = s; P s = Q\<rbrakk> \<Longrightarrow> P t = Q"
--- a/src/HOLCF/Tools/fixrec_package.ML	Mon May 11 08:24:35 2009 -0700
+++ b/src/HOLCF/Tools/fixrec_package.ML	Mon May 11 08:28:09 2009 -0700
@@ -19,8 +19,8 @@
 structure FixrecPackage :> FIXREC_PACKAGE =
 struct
 
-val fix_eq2 = @{thm fix_eq2};
-val def_fix_ind = @{thm def_fix_ind};
+val def_cont_fix_eq = @{thm def_cont_fix_eq};
+val def_cont_fix_ind = @{thm def_cont_fix_ind};
 
 
 fun fixrec_err s = error ("fixrec definition error:\n" ^ s);
@@ -55,7 +55,7 @@
 fun dest_maybeT (Type(@{type_name "maybe"}, [T])) = T
   | dest_maybeT T = raise TYPE ("dest_maybeT", [T], []);
 
-fun tupleT [] = @{typ "unit"}
+fun tupleT [] = HOLogic.unitT
   | tupleT [T] = T
   | tupleT (T :: Ts) = HOLogic.mk_prodT (T, tupleT Ts);
 
@@ -82,6 +82,10 @@
 fun cabs_const (S, T) =
   Const(@{const_name Abs_CFun}, (S --> T) --> (S ->> T));
 
+fun mk_cabs t =
+  let val T = Term.fastype_of t
+  in cabs_const (Term.domain_type T, Term.range_type T) $ t end
+
 fun mk_capply (t, u) =
   let val (S, T) =
     case Term.fastype_of t of
@@ -93,29 +97,6 @@
 infix 1 ===; val (op ===) = HOLogic.mk_eq;
 infix 9 `  ; val (op `) = mk_capply;
 
-
-fun mk_cpair (t, u) =
-  let val T = Term.fastype_of t
-      val U = Term.fastype_of u
-      val cpairT = T ->> U ->> HOLogic.mk_prodT (T, U)
-  in Const(@{const_name cpair}, cpairT) ` t ` u end;
-
-fun mk_cfst t =
-  let val T = Term.fastype_of t;
-      val (U, _) = HOLogic.dest_prodT T;
-  in Const(@{const_name cfst}, T ->> U) ` t end;
-
-fun mk_csnd t =
-  let val T = Term.fastype_of t;
-      val (_, U) = HOLogic.dest_prodT T;
-  in Const(@{const_name csnd}, T ->> U) ` t end;
-
-fun mk_csplit t =
-  let val (S, TU) = dest_cfunT (Term.fastype_of t);
-      val (T, U) = dest_cfunT TU;
-      val csplitT = (S ->> T ->> U) ->> HOLogic.mk_prodT (S, T) ->> U;
-  in Const(@{const_name csplit}, csplitT) ` t end;
-
 (* builds the expression (LAM v. rhs) *)
 fun big_lambda v rhs =
   cabs_const (Term.fastype_of v, Term.fastype_of rhs) $ Term.lambda v rhs;
@@ -124,17 +105,6 @@
 fun big_lambdas [] rhs = rhs
   | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
 
-(* builds the expression (LAM <v1,v2,..,vn>. rhs) *)
-fun lambda_ctuple [] rhs = big_lambda (Free("unit", HOLogic.unitT)) rhs
-  | lambda_ctuple (v::[]) rhs = big_lambda v rhs
-  | lambda_ctuple (v::vs) rhs =
-      mk_csplit (big_lambda v (lambda_ctuple vs rhs));
-
-(* builds the expression <v1,v2,..,vn> *)
-fun mk_ctuple [] = @{term "UU::unit"}
-|   mk_ctuple (t::[]) = t
-|   mk_ctuple (t::ts) = mk_cpair (t, mk_ctuple ts);
-
 fun mk_return t =
   let val T = Term.fastype_of t
   in Const(@{const_name Fixrec.return}, T ->> maybeT T) ` t end;
@@ -157,6 +127,25 @@
   let val (T, _) = dest_cfunT (Term.fastype_of t)
   in Const(@{const_name fix}, (T ->> T) ->> T) ` t end;
 
+fun mk_cont t =
+  let val T = Term.fastype_of t
+  in Const(@{const_name cont}, T --> HOLogic.boolT) $ t end;
+
+val mk_fst = HOLogic.mk_fst
+val mk_snd = HOLogic.mk_snd
+
+(* builds the expression (v1,v2,..,vn) *)
+fun mk_tuple [] = HOLogic.unit
+|   mk_tuple (t::[]) = t
+|   mk_tuple (t::ts) = HOLogic.mk_prod (t, mk_tuple ts);
+
+(* builds the expression (%(v1,v2,..,vn). rhs) *)
+fun lambda_tuple [] rhs = Term.lambda (Free("unit", HOLogic.unitT)) rhs
+  | lambda_tuple (v::[]) rhs = Term.lambda v rhs
+  | lambda_tuple (v::vs) rhs =
+      HOLogic.mk_split (Term.lambda v (lambda_tuple vs rhs));
+
+
 (*************************************************************************)
 (************* fixed-point definitions and unfolding theorems ************)
 (*************************************************************************)
@@ -166,40 +155,48 @@
   (spec : (Attrib.binding * term) list)
   (lthy : local_theory) =
   let
+    val thy = ProofContext.theory_of lthy;
     val names = map (Binding.name_of o fst o fst) fixes;
     val all_names = space_implode "_" names;
     val (lhss,rhss) = ListPair.unzip (map (dest_eqs o snd) spec);
-    val fixpoint = mk_fix (lambda_ctuple lhss (mk_ctuple rhss));
+    val functional = lambda_tuple lhss (mk_tuple rhss);
+    val fixpoint = mk_fix (mk_cabs functional);
     
+    val cont_thm =
+      Goal.prove lthy [] [] (mk_trp (mk_cont functional))
+        (K (simp_tac (local_simpset_of lthy) 1));
+
     fun one_def (l as Free(n,_)) r =
           let val b = Long_Name.base_name n
           in ((Binding.name (b^"_def"), []), r) end
       | one_def _ _ = fixrec_err "fixdefs: lhs not of correct form";
     fun defs [] _ = []
       | defs (l::[]) r = [one_def l r]
-      | defs (l::ls) r = one_def l (mk_cfst r) :: defs ls (mk_csnd r);
+      | defs (l::ls) r = one_def l (mk_fst r) :: defs ls (mk_snd r);
     val fixdefs = defs lhss fixpoint;
     val define_all = fold_map (LocalTheory.define Thm.definitionK);
     val (fixdef_thms : (term * (string * thm)) list, lthy') = lthy
       |> define_all (map (apfst fst) fixes ~~ fixdefs);
-    fun cpair_equalI (thm1, thm2) = @{thm cpair_equalI} OF [thm1, thm2];
-    val ctuple_fixdef_thm = foldr1 cpair_equalI (map (snd o snd) fixdef_thms);
-    val ctuple_induct_thm = ctuple_fixdef_thm RS def_fix_ind;
-    val ctuple_unfold_thm =
-      Goal.prove lthy' [] [] (mk_trp (mk_ctuple lhss === mk_ctuple rhss))
-        (fn _ => EVERY [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1,
-                   simp_tac (local_simpset_of lthy') 1]);
+    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
+    val tuple_fixdef_thm = foldr1 pair_equalI (map (snd o snd) fixdef_thms);
+    val P = Var (("P", 0), map Term.fastype_of lhss ---> HOLogic.boolT);
+    val predicate = lambda_tuple lhss (list_comb (P, lhss));
+    val tuple_induct_thm = (def_cont_fix_ind OF [tuple_fixdef_thm, cont_thm])
+      |> Drule.instantiate' [] [SOME (Thm.cterm_of thy predicate)]
+      |> LocalDefs.unfold lthy @{thms split_paired_all split_conv split_strict};
+    val tuple_unfold_thm = (def_cont_fix_eq OF [tuple_fixdef_thm, cont_thm])
+      |> LocalDefs.unfold lthy' @{thms split_conv};
     fun unfolds [] thm = []
       | unfolds (n::[]) thm = [(n^"_unfold", thm)]
       | unfolds (n::ns) thm = let
-          val thmL = thm RS @{thm cpair_eqD1};
-          val thmR = thm RS @{thm cpair_eqD2};
+          val thmL = thm RS @{thm Pair_eqD1};
+          val thmR = thm RS @{thm Pair_eqD2};
         in (n^"_unfold", thmL) :: unfolds ns thmR end;
-    val unfold_thms = unfolds names ctuple_unfold_thm;
+    val unfold_thms = unfolds names tuple_unfold_thm;
     fun mk_note (n, thm) = ((Binding.name n, []), [thm]);
     val (thmss, lthy'') = lthy'
       |> fold_map (LocalTheory.note Thm.theoremK o mk_note)
-        ((all_names ^ "_induct", ctuple_induct_thm) :: unfold_thms);
+        ((all_names ^ "_induct", tuple_induct_thm) :: unfold_thms);
   in
     (lthy'', names, fixdef_thms, map snd unfold_thms)
   end;