simplify fixrec proofs for mutually-recursive definitions; generate better fixpoint induction rules
authorhuffman
Mon May 11 08:28:09 2009 -0700 (2009-05-11)
changeset 31095b79d140f6d0b
parent 31094 7d6edb28bdbc
child 31096 e546e15089ef
simplify fixrec proofs for mutually-recursive definitions; generate better fixpoint induction rules
src/HOLCF/Fixrec.thy
src/HOLCF/Tools/fixrec_package.ML
     1.1 --- a/src/HOLCF/Fixrec.thy	Mon May 11 08:24:35 2009 -0700
     1.2 +++ b/src/HOLCF/Fixrec.thy	Mon May 11 08:28:09 2009 -0700
     1.3 @@ -574,15 +574,23 @@
     1.4    fixed-point definitions of mutually recursive functions.
     1.5  *}
     1.6  
     1.7 -lemma cpair_equalI: "\<lbrakk>x \<equiv> cfst\<cdot>p; y \<equiv> csnd\<cdot>p\<rbrakk> \<Longrightarrow> <x,y> \<equiv> p"
     1.8 -by (simp add: surjective_pairing_Cprod2)
     1.9 +lemma Pair_equalI: "\<lbrakk>x \<equiv> fst p; y \<equiv> snd p\<rbrakk> \<Longrightarrow> (x, y) \<equiv> p"
    1.10 +by simp
    1.11  
    1.12 -lemma cpair_eqD1: "<x,y> = <x',y'> \<Longrightarrow> x = x'"
    1.13 +lemma Pair_eqD1: "(x, y) = (x', y') \<Longrightarrow> x = x'"
    1.14  by simp
    1.15  
    1.16 -lemma cpair_eqD2: "<x,y> = <x',y'> \<Longrightarrow> y = y'"
    1.17 +lemma Pair_eqD2: "(x, y) = (x', y') \<Longrightarrow> y = y'"
    1.18  by simp
    1.19  
    1.20 +lemma def_cont_fix_eq:
    1.21 +  "\<lbrakk>f \<equiv> fix\<cdot>(Abs_CFun F); cont F\<rbrakk> \<Longrightarrow> f = F f"
    1.22 +by (simp, subst fix_eq, simp)
    1.23 +
    1.24 +lemma def_cont_fix_ind:
    1.25 +  "\<lbrakk>f \<equiv> fix\<cdot>(Abs_CFun F); cont F; adm P; P \<bottom>; \<And>x. P x \<Longrightarrow> P (F x)\<rbrakk> \<Longrightarrow> P f"
    1.26 +by (simp add: fix_ind)
    1.27 +
    1.28  text {* lemma for proving rewrite rules *}
    1.29  
    1.30  lemma ssubst_lhs: "\<lbrakk>t = s; P s = Q\<rbrakk> \<Longrightarrow> P t = Q"
     2.1 --- a/src/HOLCF/Tools/fixrec_package.ML	Mon May 11 08:24:35 2009 -0700
     2.2 +++ b/src/HOLCF/Tools/fixrec_package.ML	Mon May 11 08:28:09 2009 -0700
     2.3 @@ -19,8 +19,8 @@
     2.4  structure FixrecPackage :> FIXREC_PACKAGE =
     2.5  struct
     2.6  
     2.7 -val fix_eq2 = @{thm fix_eq2};
     2.8 -val def_fix_ind = @{thm def_fix_ind};
     2.9 +val def_cont_fix_eq = @{thm def_cont_fix_eq};
    2.10 +val def_cont_fix_ind = @{thm def_cont_fix_ind};
    2.11  
    2.12  
    2.13  fun fixrec_err s = error ("fixrec definition error:\n" ^ s);
    2.14 @@ -55,7 +55,7 @@
    2.15  fun dest_maybeT (Type(@{type_name "maybe"}, [T])) = T
    2.16    | dest_maybeT T = raise TYPE ("dest_maybeT", [T], []);
    2.17  
    2.18 -fun tupleT [] = @{typ "unit"}
    2.19 +fun tupleT [] = HOLogic.unitT
    2.20    | tupleT [T] = T
    2.21    | tupleT (T :: Ts) = HOLogic.mk_prodT (T, tupleT Ts);
    2.22  
    2.23 @@ -82,6 +82,10 @@
    2.24  fun cabs_const (S, T) =
    2.25    Const(@{const_name Abs_CFun}, (S --> T) --> (S ->> T));
    2.26  
    2.27 +fun mk_cabs t =
    2.28 +  let val T = Term.fastype_of t
    2.29 +  in cabs_const (Term.domain_type T, Term.range_type T) $ t end
    2.30 +
    2.31  fun mk_capply (t, u) =
    2.32    let val (S, T) =
    2.33      case Term.fastype_of t of
    2.34 @@ -93,29 +97,6 @@
    2.35  infix 1 ===; val (op ===) = HOLogic.mk_eq;
    2.36  infix 9 `  ; val (op `) = mk_capply;
    2.37  
    2.38 -
    2.39 -fun mk_cpair (t, u) =
    2.40 -  let val T = Term.fastype_of t
    2.41 -      val U = Term.fastype_of u
    2.42 -      val cpairT = T ->> U ->> HOLogic.mk_prodT (T, U)
    2.43 -  in Const(@{const_name cpair}, cpairT) ` t ` u end;
    2.44 -
    2.45 -fun mk_cfst t =
    2.46 -  let val T = Term.fastype_of t;
    2.47 -      val (U, _) = HOLogic.dest_prodT T;
    2.48 -  in Const(@{const_name cfst}, T ->> U) ` t end;
    2.49 -
    2.50 -fun mk_csnd t =
    2.51 -  let val T = Term.fastype_of t;
    2.52 -      val (_, U) = HOLogic.dest_prodT T;
    2.53 -  in Const(@{const_name csnd}, T ->> U) ` t end;
    2.54 -
    2.55 -fun mk_csplit t =
    2.56 -  let val (S, TU) = dest_cfunT (Term.fastype_of t);
    2.57 -      val (T, U) = dest_cfunT TU;
    2.58 -      val csplitT = (S ->> T ->> U) ->> HOLogic.mk_prodT (S, T) ->> U;
    2.59 -  in Const(@{const_name csplit}, csplitT) ` t end;
    2.60 -
    2.61  (* builds the expression (LAM v. rhs) *)
    2.62  fun big_lambda v rhs =
    2.63    cabs_const (Term.fastype_of v, Term.fastype_of rhs) $ Term.lambda v rhs;
    2.64 @@ -124,17 +105,6 @@
    2.65  fun big_lambdas [] rhs = rhs
    2.66    | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
    2.67  
    2.68 -(* builds the expression (LAM <v1,v2,..,vn>. rhs) *)
    2.69 -fun lambda_ctuple [] rhs = big_lambda (Free("unit", HOLogic.unitT)) rhs
    2.70 -  | lambda_ctuple (v::[]) rhs = big_lambda v rhs
    2.71 -  | lambda_ctuple (v::vs) rhs =
    2.72 -      mk_csplit (big_lambda v (lambda_ctuple vs rhs));
    2.73 -
    2.74 -(* builds the expression <v1,v2,..,vn> *)
    2.75 -fun mk_ctuple [] = @{term "UU::unit"}
    2.76 -|   mk_ctuple (t::[]) = t
    2.77 -|   mk_ctuple (t::ts) = mk_cpair (t, mk_ctuple ts);
    2.78 -
    2.79  fun mk_return t =
    2.80    let val T = Term.fastype_of t
    2.81    in Const(@{const_name Fixrec.return}, T ->> maybeT T) ` t end;
    2.82 @@ -157,6 +127,25 @@
    2.83    let val (T, _) = dest_cfunT (Term.fastype_of t)
    2.84    in Const(@{const_name fix}, (T ->> T) ->> T) ` t end;
    2.85  
    2.86 +fun mk_cont t =
    2.87 +  let val T = Term.fastype_of t
    2.88 +  in Const(@{const_name cont}, T --> HOLogic.boolT) $ t end;
    2.89 +
    2.90 +val mk_fst = HOLogic.mk_fst
    2.91 +val mk_snd = HOLogic.mk_snd
    2.92 +
    2.93 +(* builds the expression (v1,v2,..,vn) *)
    2.94 +fun mk_tuple [] = HOLogic.unit
    2.95 +|   mk_tuple (t::[]) = t
    2.96 +|   mk_tuple (t::ts) = HOLogic.mk_prod (t, mk_tuple ts);
    2.97 +
    2.98 +(* builds the expression (%(v1,v2,..,vn). rhs) *)
    2.99 +fun lambda_tuple [] rhs = Term.lambda (Free("unit", HOLogic.unitT)) rhs
   2.100 +  | lambda_tuple (v::[]) rhs = Term.lambda v rhs
   2.101 +  | lambda_tuple (v::vs) rhs =
   2.102 +      HOLogic.mk_split (Term.lambda v (lambda_tuple vs rhs));
   2.103 +
   2.104 +
   2.105  (*************************************************************************)
   2.106  (************* fixed-point definitions and unfolding theorems ************)
   2.107  (*************************************************************************)
   2.108 @@ -166,40 +155,48 @@
   2.109    (spec : (Attrib.binding * term) list)
   2.110    (lthy : local_theory) =
   2.111    let
   2.112 +    val thy = ProofContext.theory_of lthy;
   2.113      val names = map (Binding.name_of o fst o fst) fixes;
   2.114      val all_names = space_implode "_" names;
   2.115      val (lhss,rhss) = ListPair.unzip (map (dest_eqs o snd) spec);
   2.116 -    val fixpoint = mk_fix (lambda_ctuple lhss (mk_ctuple rhss));
   2.117 +    val functional = lambda_tuple lhss (mk_tuple rhss);
   2.118 +    val fixpoint = mk_fix (mk_cabs functional);
   2.119      
   2.120 +    val cont_thm =
   2.121 +      Goal.prove lthy [] [] (mk_trp (mk_cont functional))
   2.122 +        (K (simp_tac (local_simpset_of lthy) 1));
   2.123 +
   2.124      fun one_def (l as Free(n,_)) r =
   2.125            let val b = Long_Name.base_name n
   2.126            in ((Binding.name (b^"_def"), []), r) end
   2.127        | one_def _ _ = fixrec_err "fixdefs: lhs not of correct form";
   2.128      fun defs [] _ = []
   2.129        | defs (l::[]) r = [one_def l r]
   2.130 -      | defs (l::ls) r = one_def l (mk_cfst r) :: defs ls (mk_csnd r);
   2.131 +      | defs (l::ls) r = one_def l (mk_fst r) :: defs ls (mk_snd r);
   2.132      val fixdefs = defs lhss fixpoint;
   2.133      val define_all = fold_map (LocalTheory.define Thm.definitionK);
   2.134      val (fixdef_thms : (term * (string * thm)) list, lthy') = lthy
   2.135        |> define_all (map (apfst fst) fixes ~~ fixdefs);
   2.136 -    fun cpair_equalI (thm1, thm2) = @{thm cpair_equalI} OF [thm1, thm2];
   2.137 -    val ctuple_fixdef_thm = foldr1 cpair_equalI (map (snd o snd) fixdef_thms);
   2.138 -    val ctuple_induct_thm = ctuple_fixdef_thm RS def_fix_ind;
   2.139 -    val ctuple_unfold_thm =
   2.140 -      Goal.prove lthy' [] [] (mk_trp (mk_ctuple lhss === mk_ctuple rhss))
   2.141 -        (fn _ => EVERY [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1,
   2.142 -                   simp_tac (local_simpset_of lthy') 1]);
   2.143 +    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
   2.144 +    val tuple_fixdef_thm = foldr1 pair_equalI (map (snd o snd) fixdef_thms);
   2.145 +    val P = Var (("P", 0), map Term.fastype_of lhss ---> HOLogic.boolT);
   2.146 +    val predicate = lambda_tuple lhss (list_comb (P, lhss));
   2.147 +    val tuple_induct_thm = (def_cont_fix_ind OF [tuple_fixdef_thm, cont_thm])
   2.148 +      |> Drule.instantiate' [] [SOME (Thm.cterm_of thy predicate)]
   2.149 +      |> LocalDefs.unfold lthy @{thms split_paired_all split_conv split_strict};
   2.150 +    val tuple_unfold_thm = (def_cont_fix_eq OF [tuple_fixdef_thm, cont_thm])
   2.151 +      |> LocalDefs.unfold lthy' @{thms split_conv};
   2.152      fun unfolds [] thm = []
   2.153        | unfolds (n::[]) thm = [(n^"_unfold", thm)]
   2.154        | unfolds (n::ns) thm = let
   2.155 -          val thmL = thm RS @{thm cpair_eqD1};
   2.156 -          val thmR = thm RS @{thm cpair_eqD2};
   2.157 +          val thmL = thm RS @{thm Pair_eqD1};
   2.158 +          val thmR = thm RS @{thm Pair_eqD2};
   2.159          in (n^"_unfold", thmL) :: unfolds ns thmR end;
   2.160 -    val unfold_thms = unfolds names ctuple_unfold_thm;
   2.161 +    val unfold_thms = unfolds names tuple_unfold_thm;
   2.162      fun mk_note (n, thm) = ((Binding.name n, []), [thm]);
   2.163      val (thmss, lthy'') = lthy'
   2.164        |> fold_map (LocalTheory.note Thm.theoremK o mk_note)
   2.165 -        ((all_names ^ "_induct", ctuple_induct_thm) :: unfold_thms);
   2.166 +        ((all_names ^ "_induct", tuple_induct_thm) :: unfold_thms);
   2.167    in
   2.168      (lthy'', names, fixdef_thms, map snd unfold_thms)
   2.169    end;