fixrec package now handles mutually-recursive definitions
authorhuffman
Wed Jun 15 20:50:38 2005 +0200 (2005-06-15)
changeset 1640157c35ede00b9
parent 16400 f2ab5797bbd0
child 16402 36f41d5e3b3e
fixrec package now handles mutually-recursive definitions
src/HOLCF/Fixrec.thy
src/HOLCF/fixrec_package.ML
     1.1 --- a/src/HOLCF/Fixrec.thy	Wed Jun 15 14:59:25 2005 +0200
     1.2 +++ b/src/HOLCF/Fixrec.thy	Wed Jun 15 20:50:38 2005 +0200
     1.3 @@ -134,6 +134,27 @@
     1.4    "match_up\<cdot>\<bottom> = \<bottom>"
     1.5  by (simp_all add: match_up_def)
     1.6  
     1.7 +subsection {* Mutual recursion *}
     1.8 +
     1.9 +text {*
    1.10 +  The following rules are used to prove unfolding theorems from
    1.11 +  fixed-point definitions of mutually recursive functions.
    1.12 +*}
    1.13 +
    1.14 +lemma cpair_equalI: "\<lbrakk>x \<equiv> cfst\<cdot>p; y \<equiv> csnd\<cdot>p\<rbrakk> \<Longrightarrow> <x,y> \<equiv> p"
    1.15 +by (simp add: surjective_pairing_Cprod2)
    1.16 +
    1.17 +lemma cpair_eqD1: "<x,y> = <x',y'> \<Longrightarrow> x = x'"
    1.18 +by simp
    1.19 +
    1.20 +lemma cpair_eqD2: "<x,y> = <x',y'> \<Longrightarrow> y = y'"
    1.21 +by simp
    1.22 +
    1.23 +ML {*
    1.24 +val cpair_equalI = thm "cpair_equalI";
    1.25 +val cpair_eqD1 = thm "cpair_eqD1";
    1.26 +val cpair_eqD2 = thm "cpair_eqD2";
    1.27 +*}
    1.28  
    1.29  subsection {* Intitializing the fixrec package *}
    1.30  
     2.1 --- a/src/HOLCF/fixrec_package.ML	Wed Jun 15 14:59:25 2005 +0200
     2.2 +++ b/src/HOLCF/fixrec_package.ML	Wed Jun 15 20:50:38 2005 +0200
     2.3 @@ -7,7 +7,7 @@
     2.4  
     2.5  signature FIXREC_PACKAGE =
     2.6  sig
     2.7 -  val add_fixrec: string list -> theory -> theory
     2.8 +  val add_fixrec: string list list -> theory -> theory
     2.9    val add_fixpat: string * string -> theory -> theory
    2.10  end;
    2.11  
    2.12 @@ -27,6 +27,8 @@
    2.13  		   ("o"::"p"::" "::rest) => implode rest
    2.14  		   | _ => con;
    2.15  
    2.16 +val mk_trp = HOLogic.mk_Trueprop;
    2.17 +
    2.18  (* splits a cterm into the right and lefthand sides of equality *)
    2.19  fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
    2.20    | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs))    = (lhs,rhs)
    2.21 @@ -37,57 +39,92 @@
    2.22  fun %%: s = Const(s,dummyT);
    2.23  infix 0 ==;  fun S ==  T = %%:"==" $ S $ T;
    2.24  infix 1 ===; fun S === T = %%:"op =" $ S $ T;
    2.25 -infix 9 `  ; fun f ` x = %%:"Rep_CFun" $ f $ x;
    2.26 +infix 9 `  ; fun f ` x = %%:"Cfun.Rep_CFun" $ f $ x;
    2.27  
    2.28  (* infers the type of a term *)
    2.29 -fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
    2.30 +(* similar to Theory.inferT_axm, but allows any type *)
    2.31 +fun infer sg t =
    2.32 +  fst (Sign.infer_types (Sign.pp sg) sg (K NONE) (K NONE) [] true ([t],dummyT));
    2.33 +
    2.34 +(* The next few functions build continuous lambda abstractions *)
    2.35 +
    2.36 +(* Similar to Term.lambda, but allows abstraction over constants *)
    2.37 +fun lambda' (v as Free (x, T)) t = Abs (x, T, abstract_over (v, t))
    2.38 +  | lambda' (v as Var ((x, _), T)) t = Abs (x, T, abstract_over (v, t))
    2.39 +  | lambda' (v as Const (x, T)) t = Abs (Sign.base_name x, T, abstract_over (v, t))
    2.40 +  | lambda' v t = raise TERM ("lambda'", [v, t]);
    2.41 +
    2.42 +(* builds the expression (LAM v. rhs) *)
    2.43 +fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda' v rhs);
    2.44 +
    2.45 +(* builds the expression (LAM v1 v2 .. vn. rhs) *)
    2.46 +fun big_lambdas [] rhs = rhs
    2.47 +  | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
    2.48 +
    2.49 +(* builds the expression (LAM <v1,v2,..,vn>. rhs) *)
    2.50 +fun lambda_ctuple [] rhs = big_lambda (%:"unit") rhs
    2.51 +  | lambda_ctuple (v::[]) rhs = big_lambda v rhs
    2.52 +  | lambda_ctuple (v::vs) rhs =
    2.53 +      %%:"Cprod.csplit"`(big_lambda v (lambda_ctuple vs rhs));
    2.54 +
    2.55 +(* builds the expression <v1,v2,..,vn> *)
    2.56 +fun mk_ctuple [] = %%:"UU"
    2.57 +|   mk_ctuple (t::[]) = t
    2.58 +|   mk_ctuple (t::ts) = %%:"Cprod.cpair"`t`(mk_ctuple ts);
    2.59  
    2.60  (*************************************************************************)
    2.61 -(************ fixed-point definitions and unfolding theorems *************)
    2.62 +(************************ fixed-point definitions ************************)
    2.63  (*************************************************************************)
    2.64  
    2.65 -fun func1 (lhs as Const(name,T), rhs) =
    2.66 +fun add_fixdefs eqs thy =
    2.67    let
    2.68 -    val basename = Sign.base_name name;
    2.69 -    val funcT = T ->> T;
    2.70 -    val functional = Const ("Cfun.Abs_CFun", (T --> T) --> funcT) $
    2.71 -          Abs (basename, T, abstract_over (lhs,rhs));
    2.72 -    val fix_type = funcT ->> T;
    2.73 -    val fix_const = Const ("Fix.fix", fix_type);
    2.74 -    val func_type = fix_type --> funcT --> T;
    2.75 -    val rhs' = Const ("Cfun.Rep_CFun",func_type)$fix_const$functional;
    2.76 +    val (lhss,rhss) = ListPair.unzip (map dest_eqs eqs);
    2.77 +    val fixpoint = %%:"Fix.fix"`lambda_ctuple lhss (mk_ctuple rhss);
    2.78 +    
    2.79 +    fun one_def (l as Const(n,T)) r =
    2.80 +          let val b = Sign.base_name n in (b, (b^"_fixdef", l == r)) end
    2.81 +      | one_def _ _ = sys_error "fixdefs: lhs not of correct form";
    2.82 +    fun defs [] _ = []
    2.83 +      | defs (l::[]) r = [one_def l r]
    2.84 +      | defs (l::ls) r = one_def l (%%:"Cprod.cfst"`r) :: defs ls (%%:"Cprod.csnd"`r);
    2.85 +    val (names, pre_fixdefs) = ListPair.unzip (defs lhss fixpoint);
    2.86 +    
    2.87 +    val fixdefs = map (inferT_axm (sign_of thy)) pre_fixdefs;
    2.88 +    val (thy', fixdef_thms) =
    2.89 +      PureThy.add_defs_i false (map Thm.no_attributes fixdefs) thy;
    2.90 +    val ctuple_fixdef_thm = foldr1 (fn (x,y) => cpair_equalI OF [x,y]) fixdef_thms;
    2.91 +    
    2.92 +    fun mk_cterm t = let val sg' = sign_of thy' in cterm_of sg' (infer sg' t) end;
    2.93 +    val ctuple_unfold_ct = mk_cterm (mk_trp (mk_ctuple lhss === mk_ctuple rhss));
    2.94 +    val ctuple_unfold_thm = prove_goalw_cterm [] ctuple_unfold_ct
    2.95 +          (fn _ => [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1,
    2.96 +                    simp_tac (simpset_of thy') 1]);
    2.97 +    val ctuple_induct_thm =
    2.98 +          (space_implode "_" names ^ "_induct", [ctuple_fixdef_thm RS def_fix_ind]);
    2.99 +    
   2.100 +    fun unfolds [] thm = []
   2.101 +      | unfolds (n::[]) thm = [(n^"_unfold", [thm])]
   2.102 +      | unfolds (n::ns) thm = let
   2.103 +          val thmL = thm RS cpair_eqD1;
   2.104 +          val thmR = thm RS cpair_eqD2;
   2.105 +        in (n^"_unfold", [thmL]) :: unfolds ns thmR end;
   2.106 +    val unfold_thmss = unfolds names ctuple_unfold_thm;
   2.107 +    val thmss = ctuple_induct_thm :: unfold_thmss;
   2.108 +    val (thy'', _) = PureThy.add_thmss (map Thm.no_attributes thmss) thy';
   2.109    in
   2.110 -    (name, (basename^"_fixdef", equals T $ lhs $ rhs'))
   2.111 -  end
   2.112 -  | func1 t = sys_error "func1: not of correct form";
   2.113 +    (thy'', names, fixdef_thms, List.concat (map snd unfold_thmss))
   2.114 +  end;
   2.115  
   2.116  (*************************************************************************)
   2.117  (*********** monadic notation and pattern matching compilation ***********)
   2.118  (*************************************************************************)
   2.119  
   2.120 -(* these 3 functions strip off parameters and destruct constructors *)
   2.121 -(*
   2.122 -fun strip_cpair (Const("Cfun.Rep_CFun",_) $
   2.123 -      (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) =
   2.124 -        b :: strip_cpair r
   2.125 -  | strip_cpair c = [c];
   2.126 -*)
   2.127 -fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs);
   2.128 -
   2.129 -fun big_lambdas [] rhs = rhs
   2.130 -  | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
   2.131 +fun add_names (Const(a,_), bs) = Sign.base_name a ins_string bs
   2.132 +  | add_names (Free(a,_) , bs) = a ins_string bs
   2.133 +  | add_names (f $ u     , bs) = add_names (f, add_names(u, bs))
   2.134 +  | add_names (Abs(a,_,t), bs) = add_names (t, a ins_string bs)
   2.135 +  | add_names (_         , bs) = bs;
   2.136  
   2.137 -(* builds a big lamdba expression with a tuple *)
   2.138 -fun lambda_tuple [] rhs = big_lambda (%:"unit") rhs
   2.139 -  | lambda_tuple [v] rhs = big_lambda v rhs
   2.140 -  | lambda_tuple (v::vs) rhs =
   2.141 -      %%:"Cprod.csplit"`(big_lambda v (lambda_tuple vs rhs));
   2.142 -
   2.143 -fun add_names (Const(a,_), bs) = NameSpace.base a ins_string bs
   2.144 -  | add_names (Free(a,_), bs) = a ins_string bs
   2.145 -  | add_names (f$u, bs) = add_names (f, add_names(u, bs))
   2.146 -  | add_names (Abs(a,_,t), bs) = add_names(t,a ins_string bs)
   2.147 -  | add_names (_, bs) = bs;
   2.148  fun add_terms ts xs = foldr add_names xs ts;
   2.149  
   2.150  (* builds a monadic term for matching a constructor pattern *)
   2.151 @@ -104,8 +141,8 @@
   2.152          fun result_type (Type("Cfun.->",[_,T])) (x::xs) = result_type T xs
   2.153            | result_type T _ = T;
   2.154          val v = Free(n, result_type T vs);
   2.155 -        val m = "match_"^(extern_name(NameSpace.base c));
   2.156 -        val k = lambda_tuple vs rhs;
   2.157 +        val m = "match_"^(extern_name(Sign.base_name c));
   2.158 +        val k = lambda_ctuple vs rhs;
   2.159        in
   2.160          (%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken)
   2.161        end;
   2.162 @@ -159,14 +196,14 @@
   2.163            else sys_error "FIXREC: all equations must have the same arity";
   2.164      val rhs = fatbar arity mats;
   2.165    in
   2.166 -    HOLogic.mk_Trueprop (%%:cname === rhs)
   2.167 +    mk_trp (%%:cname === rhs)
   2.168    end;
   2.169  
   2.170  (*************************************************************************)
   2.171  (********************** Proving associated theorems **********************)
   2.172  (*************************************************************************)
   2.173  
   2.174 -fun prove_thm thy unfold_thm ct =
   2.175 +fun prove_rew thy unfold_thm ct =
   2.176    let
   2.177      val ss = simpset_of thy;
   2.178      val thm = prove_goalw_cterm [] ct
   2.179 @@ -175,18 +212,24 @@
   2.180    in thm end;
   2.181  
   2.182  (* this proves that each equation is a theorem *)
   2.183 -fun prove_list thy unfold_thm [] = []
   2.184 -  | prove_list thy unfold_thm (x::xs) =
   2.185 -      prove_thm thy unfold_thm x :: prove_list thy unfold_thm xs;
   2.186 +fun prove_rews thy (unfold_thm,cts) = map (prove_rew thy unfold_thm) cts;
   2.187 +
   2.188 +(* proves the pattern matching equations as theorems, using unfold *)
   2.189 +fun make_simps names unfold_thms ctss thy = 
   2.190 +  let
   2.191 +    val thm_names = map (fn name => name^"_rews") names;
   2.192 +    val rew_thmss = ListPair.map (prove_rews thy) (unfold_thms, ctss);
   2.193 +    val thmss = ListPair.zip (thm_names, rew_thmss);
   2.194 +  in
   2.195 +    (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy
   2.196 +  end;
   2.197  
   2.198  (* this proves the def without fix is a theorem, this uses the fixpoint def *)
   2.199 +(*
   2.200  fun make_simp name eqs ct fixdef_thm thy' = 
   2.201    let
   2.202 -    val basename = NameSpace.base name;
   2.203      val ss = simpset_of thy';
   2.204      val eq_thm = fixdef_thm RS fix_eq2;
   2.205 -    val unfold_thm = prove_goalw_cterm [] ct
   2.206 -      (fn _ => [(rtac (eq_thm RS trans) 1) THEN (simp_tac ss 1)]);
   2.207      val ind_thm = fixdef_thm RS def_fix_ind;
   2.208      val rew_thms = prove_list thy' unfold_thm eqs;
   2.209      val thmss =
   2.210 @@ -196,25 +239,21 @@
   2.211    in
   2.212      (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy'
   2.213    end;
   2.214 -
   2.215 +*)
   2.216  (*************************************************************************)
   2.217  (************************* Main fixrec function **************************)
   2.218  (*************************************************************************)
   2.219  
   2.220  (* this calls the main processing function and then returns the new state *)
   2.221 -fun add_fixrec strs thy =
   2.222 +fun add_fixrec strss thy =
   2.223    let
   2.224      val sg = sign_of thy;
   2.225 -    val cts = map (Thm.read_cterm sg o rpair propT) strs;
   2.226 -    val eqs = map term_of cts;
   2.227 -    val funcc = infer (compile_pats eqs) thy;
   2.228 -    val _ = print_cterm (cterm_of sg funcc);
   2.229 -    val (name', fixdef_name_term) = func1 (dest_eqs funcc);
   2.230 -    val (thy', [fixdef_thm]) =
   2.231 -      PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy;
   2.232 -    val ct = cterm_of (sign_of thy') funcc;
   2.233 +    val ctss = map (map (Thm.read_cterm sg o rpair propT)) strss;
   2.234 +    val tss = map (map term_of) ctss;
   2.235 +    val ts' = map (fn ts => infer sg (compile_pats ts)) tss;
   2.236 +    val (thy', names, fixdef_thms, unfold_thms) = add_fixdefs ts' thy;
   2.237    in
   2.238 -    make_simp name' cts ct fixdef_thm thy'
   2.239 +    make_simps names unfold_thms ctss thy'
   2.240    end;
   2.241  
   2.242  (*************************************************************************)
   2.243 @@ -226,7 +265,7 @@
   2.244      val sign = sign_of thy;
   2.245      val t = term_of (Thm.read_cterm sign (pat, dummyT));
   2.246      val T = fastype_of t;
   2.247 -    val eq = HOLogic.mk_Trueprop (HOLogic.eq_const T $ t $ Var (("x",0),T));
   2.248 +    val eq = mk_trp (HOLogic.eq_const T $ t $ Var (("x",0),T));
   2.249      fun head_const (Const ("Cfun.Rep_CFun",_) $ f $ t) = head_const f
   2.250        | head_const (Const (c,_)) = c
   2.251        | head_const _ = sys_error "FIXPAT: function is not declared as constant in theory";
   2.252 @@ -247,7 +286,7 @@
   2.253  
   2.254  local structure P = OuterParse and K = OuterSyntax.Keyword in
   2.255  
   2.256 -val fixrec_decl = (*P.and_list1*) (Scan.repeat1 P.prop);
   2.257 +val fixrec_decl = P.and_list1 (Scan.repeat1 P.prop);
   2.258  
   2.259  (* this builds a parser for a new keyword, fixrec, whose functionality 
   2.260  is defined by add_fixrec *)