New features:
authorhuffman
Thu, 23 Jun 2005 21:27:23 +0200
changeset 16552 0774e9bcdb6c
parent 16551 7abf8a713613
child 16553 aa36d41e4263
New features: permissive option for fixrec to skip proofs of equations; side conditions for fixrec equations (for definedness); fixpat theorem names apply to entire group of theorems; improved error messages
src/HOLCF/fixrec_package.ML
--- a/src/HOLCF/fixrec_package.ML	Thu Jun 23 21:17:26 2005 +0200
+++ b/src/HOLCF/fixrec_package.ML	Thu Jun 23 21:27:23 2005 +0200
@@ -7,15 +7,23 @@
 
 signature FIXREC_PACKAGE =
 sig
-  val add_fixrec: ((string * Attrib.src list) * string) list list -> theory -> theory
-  val add_fixrec_i: ((string * theory attribute list) * term) list list -> theory -> theory
-  val add_fixpat: ((string * Attrib.src list) * string) list -> theory -> theory
-  val add_fixpat_i: ((string * theory attribute list) * term) list -> theory -> theory
+  val add_fixrec: bool -> ((string * Attrib.src list) * string) list list
+    -> theory -> theory
+  val add_fixrec_i: bool -> ((string * theory attribute list) * term) list list
+    -> theory -> theory
+  val add_fixpat: (string * Attrib.src list) * string list
+    -> theory -> theory
+  val add_fixpat_i: (string * theory attribute list) * term list
+    -> theory -> theory
 end;
 
 structure FixrecPackage: FIXREC_PACKAGE =
 struct
 
+fun fixrec_err s = error ("fixrec definition error:\n" ^ s);
+fun fixrec_eq_err sign s eq =
+  fixrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term sign eq));
+
 (* ->> is taken from holcf_logic.ML *)
 (* TODO: fix dependencies so we can import HOLCFLogic here *)
 infixr 6 ->>;
@@ -29,10 +37,12 @@
 val mk_trp = HOLogic.mk_Trueprop;
 
 (* splits a cterm into the right and lefthand sides of equality *)
+fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t);
+(*
 fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
   | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs)) = (lhs,rhs)
-  | dest_eqs t = sys_error (Sign.string_of_term (sign_of (the_context())) t);
-
+  | dest_eqs t = fixrec_err (Sign.string_of_term (sign_of (the_context())) t);
+*)
 (* similar to Thm.head_of, but for continuous application *)
 fun chead_of (Const("Cfun.Rep_CFun",_)$f$t) = chead_of f
   | chead_of u = u;
@@ -83,8 +93,8 @@
     val fixpoint = %%:"Fix.fix"`lambda_ctuple lhss (mk_ctuple rhss);
     
     fun one_def (l as Const(n,T)) r =
-          let val b = Sign.base_name n in (b, (b^"_fixdef", l == r)) end
-      | one_def _ _ = sys_error "fixdefs: lhs not of correct form";
+          let val b = Sign.base_name n in (b, (b^"_def", l == r)) end
+      | one_def _ _ = fixrec_err "fixdefs: lhs not of correct form";
     fun defs [] _ = []
       | defs (l::[]) r = [one_def l r]
       | defs (l::ls) r = one_def l (%%:"Cprod.cfst"`r) :: defs ls (%%:"Cprod.csnd"`r);
@@ -146,7 +156,9 @@
         val k = lambda_ctuple vs rhs;
       in
         (%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken)
-      end;
+      end
+  | Free(n,_) => fixrec_err ("expected constructor, found free variable " ^ quote n)
+  | _ => fixrec_err "pre_build: invalid pattern";
 
 (* builds a monadic term for matching a function definition pattern *)
 (* returns (name, arity, matcher) *)
@@ -157,15 +169,12 @@
   | Const("Cfun.Rep_CFun", _)$f$x =>
       let val (rhs', v, taken') = pre_build x rhs [] taken;
       in building f rhs' (v::vs) taken' end
-  | Const(_,_) => (pat, length vs, big_lambdas vs rhs)
-  | _ => sys_error "function is not declared as constant in theory";
+  | Const(name,_) => (name, length vs, big_lambdas vs rhs)
+  | _ => fixrec_err "function is not declared as constant in theory";
 
 fun match_eq eq = 
-  let
-    val (lhs,rhs) = dest_eqs eq;
-    val (Const(name,_), arity, term) =
-      building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []);
-  in (name, arity, term) end;
+  let val (lhs,rhs) = dest_eqs eq;
+  in building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []) end;
 
 (* returns the sum (using +++) of the terms in ms *)
 (* also applies "run" to the result! *)
@@ -173,9 +182,9 @@
   let
     fun unLAM 0 t = t
       | unLAM n (_$Abs(_,_,t)) = unLAM (n-1) t
-      | unLAM _ _ = sys_error "FIXREC: internal error, not enough LAMs";
+      | unLAM _ _ = fixrec_err "fatbar: internal error, not enough LAMs";
     fun reLAM 0 t = t
-      | reLAM n t = reLAM (n-1) (%%:"Abs_CFun" $ Abs("",dummyT,t));
+      | reLAM n t = reLAM (n-1) (%%:"Cfun.Abs_CFun" $ Abs("",dummyT,t));
     fun mplus (x,y) = %%:"Fixrec.mplus"`x`y;
     val msum = foldr1 mplus (map (unLAM arity) ms);
   in
@@ -192,9 +201,9 @@
   let
     val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs);
     val cname = if forall (fn x => n=x) names then n
-          else sys_error "FIXREC: all equations must define the same function";
+          else fixrec_err "all equations in block must define the same function";
     val arity = if forall (fn x => a=x) arities then a
-          else sys_error "FIXREC: all equations must have the same arity";
+          else fixrec_err "all equations in block must have the same arity";
     val rhs = fatbar arity mats;
   in
     mk_trp (%%:cname === rhs)
@@ -204,45 +213,52 @@
 (********************** Proving associated theorems **********************)
 (*************************************************************************)
 
-fun prove_simp thy unfold_thm t =
+(* proves a block of pattern matching equations as theorems, using unfold *)
+fun make_simps thy (unfold_thm, eqns) =
   let
-    val ss = simpset_of thy;
-    val ct = cterm_of thy t;
-    val thm = prove_goalw_cterm [] ct
-      (fn _ => [rtac (unfold_thm RS ssubst_lhs) 1, simp_tac ss 1]);
-  in thm end;
-
-(* this proves that each equation is a theorem *)
-fun prove_simps thy (unfold_thm,ts) = map (prove_simp thy unfold_thm) ts;
-
-(* proves the pattern matching equations as theorems, using unfold *)
-fun make_simps cnames unfold_thms namess attss tss thy = 
-  let
-    val thm_names = map (fn name => name^"_simps") cnames;
-    val rew_thmss = map (prove_simps thy) (unfold_thms ~~ tss);
-    val thms = (List.concat namess ~~ List.concat rew_thmss) ~~ List.concat attss;
-    val (thy',_) = PureThy.add_thms thms thy;
-    val thmss = thm_names ~~ rew_thmss;
-    val simp_attribute = rpair [Simplifier.simp_add_global];
+    fun tacsf prems =
+      [rtac (unfold_thm RS ssubst_lhs) 1, simp_tac (simpset_of thy addsimps prems) 1];
+    fun prove_term t = prove_goalw_cterm [] (cterm_of thy t) tacsf;
+    fun prove_eqn ((name, eqn_t), atts) = ((name, prove_term eqn_t), atts);
   in
-    (#1 o PureThy.add_thmss (map simp_attribute thmss)) thy'
+    map prove_eqn eqns
   end;
 
 (*************************************************************************)
 (************************* Main fixrec function **************************)
 (*************************************************************************)
 
-(* this calls the main processing function and then returns the new state *)
-fun gen_add_fixrec prep_prop prep_attrib blocks thy =
+fun gen_add_fixrec prep_prop prep_attrib strict blocks thy =
   let
-    fun split_list2 xss = split_list (map split_list xss);
-    val ((namess, srcsss), strss) = apfst split_list2 (split_list2 blocks);
-    val attss = map (map (map (prep_attrib thy))) srcsss;
-    val tss = map (map (prep_prop thy)) strss;
-    val ts' = map (infer thy o compile_pats) tss;
-    val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs ts' thy;
+    val eqns = List.concat blocks;
+    val lengths = map length blocks;
+    
+    val sign = sign_of thy;
+    val ((names, srcss), strings) = apfst split_list (split_list eqns);
+    val atts = map (map (prep_attrib thy)) srcss;
+    val eqn_ts = map (prep_prop thy) strings;
+    val rec_ts = map (fn eq => chead_of (fst (dest_eqs (Logic.strip_imp_concl eq)))
+      handle TERM _ => fixrec_eq_err sign "not a proper equation" eq) eqn_ts;
+    val (_, eqn_ts') = InductivePackage.unify_consts sign rec_ts eqn_ts;
+    
+    fun unconcat [] _ = []
+      | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));
+    val pattern_blocks = unconcat lengths (map Logic.strip_imp_concl eqn_ts');
+    val compiled_ts = map (infer sign o compile_pats) pattern_blocks;
+    val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs compiled_ts thy;
   in
-    make_simps cnames unfold_thms namess attss tss thy'
+    if strict then let (* only prove simp rules if strict = true *)
+      val eqn_blocks = unconcat lengths ((names ~~ eqn_ts') ~~ atts);
+      val simps = List.concat (map (make_simps thy') (unfold_thms ~~ eqn_blocks));
+      val (thy'', simp_thms) = PureThy.add_thms simps thy';
+      
+      val simp_names = map (fn name => name^"_simps") cnames;
+      val simp_attribute = rpair [Simplifier.simp_add_global];
+      val simps' = map simp_attribute (simp_names ~~ unconcat lengths simp_thms);
+    in
+      (#1 o PureThy.add_thmss simps') thy''
+    end
+    else thy'
   end;
 
 val add_fixrec = gen_add_fixrec Sign.read_prop Attrib.global_attribute;
@@ -253,25 +269,25 @@
 (******************************** Fixpat *********************************)
 (*************************************************************************)
 
-fun fix_pat prep_term thy pat = 
+fun fix_pat thy t = 
   let
-    val t = prep_term thy pat;
     val T = fastype_of t;
     val eq = mk_trp (HOLogic.eq_const T $ t $ Var (("x",0),T));
     val cname = case chead_of t of Const(c,_) => c | _ =>
-              sys_error "FIXPAT: function is not declared as constant in theory";
+              fixrec_err "function is not declared as constant in theory";
     val unfold_thm = PureThy.get_thm thy (Name (cname^"_unfold"));
-    val rew = prove_goalw_cterm [] (cterm_of thy eq)
+    val simp = prove_goalw_cterm [] (cterm_of thy eq)
           (fn _ => [stac unfold_thm 1, simp_tac (simpset_of thy) 1]);
-  in rew end;
+  in simp end;
 
-fun gen_add_fixpat prep_term prep_attrib pats thy =
+fun gen_add_fixpat prep_term prep_attrib ((name, srcs), strings) thy =
   let
-    val ((names, srcss), strings) = apfst ListPair.unzip (ListPair.unzip pats);
-    val atts = map (map (prep_attrib thy)) srcss;
-    val simps = map (fix_pat prep_term thy) strings;
-    val (thy', _) = PureThy.add_thms ((names ~~ simps) ~~ atts) thy;
-  in thy' end;
+    val atts = map (prep_attrib thy) srcs;
+    val ts = map (prep_term thy) strings;
+    val simps = map (fix_pat thy) ts;
+  in
+    (#1 o PureThy.add_thmss [((name, simps), atts)]) thy
+  end;
 
 val add_fixpat = gen_add_fixpat Sign.read_term Attrib.global_attribute;
 val add_fixpat_i = gen_add_fixpat Sign.cert_term (K I);
@@ -283,25 +299,27 @@
 
 local structure P = OuterParse and K = OuterSyntax.Keyword in
 
-val fixrec_decl = P.and_list1 (Scan.repeat1 (P.opt_thm_name ":" -- P.prop));
+val fixrec_eqn = P.opt_thm_name ":" -- P.prop;
+
+val fixrec_strict =
+  Scan.optional (P.$$$ "(" -- P.!!! (P.$$$ "permissive" -- P.$$$ ")") >> K false) true;
+
+val fixrec_decl = fixrec_strict -- P.and_list1 (Scan.repeat1 fixrec_eqn);
 
 (* this builds a parser for a new keyword, fixrec, whose functionality 
 is defined by add_fixrec *)
 val fixrecP =
   OuterSyntax.command "fixrec" "define recursive functions (HOLCF)" K.thy_decl
-    (fixrec_decl >> (Toplevel.theory o add_fixrec));
-
-(* this adds the parser for fixrec to the syntax *)
-val _ = OuterSyntax.add_parsers [fixrecP];
+    (fixrec_decl >> (Toplevel.theory o uncurry add_fixrec));
 
 (* fixpat parser *)
-val fixpat_decl = Scan.repeat1 (P.opt_thm_name ":" -- P.prop);
+val fixpat_decl = P.opt_thm_name ":" -- Scan.repeat1 P.prop;
 
 val fixpatP =
   OuterSyntax.command "fixpat" "define rewrites for fixrec functions" K.thy_decl
     (fixpat_decl >> (Toplevel.theory o add_fixpat));
 
-val _ = OuterSyntax.add_parsers [fixpatP];
+val _ = OuterSyntax.add_parsers [fixrecP, fixpatP];
 
 end; (* local structure *)