- Now also supports arbitrarily branching datatypes.
authorberghofe
Fri, 16 Jul 1999 13:24:41 +0200
changeset 7016 df54b5365477
parent 7015 85be09eb136c
child 7017 e4e64a0b0b6b
- Now also supports arbitrarily branching datatypes. - Fixed bug (in some rare cases, recursive constants were inconsistently typed in different primrec equations).
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Fri Jul 16 12:14:04 1999 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Fri Jul 16 13:24:41 1999 +0200
@@ -40,11 +40,11 @@
     val (lhs, rhs) = 
 	if null (term_vars eq) then
 	    HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
-	      handle _ => raise RecError "not a proper equation"
+	      handle TERM _ => raise RecError "not a proper equation"
 	else raise RecError "illegal schematic variable(s)";
 
     val (recfun, args) = strip_comb lhs;
-    val (fname, _) = dest_Const recfun handle _ => 
+    val (fname, _) = dest_Const recfun handle TERM _ => 
       raise RecError "function is not declared as constant in theory";
 
     val (ls', rest)  = take_prefix is_Free args;
@@ -54,14 +54,14 @@
     val (constr, cargs') = if null middle then raise RecError "constructor missing"
       else strip_comb (hd middle);
     val (cname, T) = dest_Const constr
-      handle _ => raise RecError "ill-formed constructor";
-    val (tname, _) = dest_Type (body_type T) handle _ =>
+      handle TERM _ => raise RecError "ill-formed constructor";
+    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
       raise RecError "cannot determine datatype associated with function"
 
     val (ls, cargs, rs) = (map dest_Free ls', 
 			   map dest_Free cargs', 
 			   map dest_Free rs')
-      handle _ => raise RecError "illegal argument in pattern";
+      handle TERM _ => raise RecError "illegal argument in pattern";
     val lfrees = ls @ rs @ cargs;
 
   in
@@ -106,9 +106,10 @@
                 val (_, rpos, _) = the (assoc (rec_eqns, fname'));
                 val ls = take (rpos, ts);
                 val rest = drop (rpos, ts);
-                val (x, rs) = (hd rest, tl rest)
-                  handle _ => raise RecError ("not enough arguments\
-                   \ in recursive application\nof function " ^ quote fname' ^ " on rhs")
+                val (x', rs) = (hd rest, tl rest)
+                  handle LIST _ => raise RecError ("not enough arguments\
+                   \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
+                val (x, xs) = strip_comb x'
               in 
                 (case assoc (subs, x) of
                     None =>
@@ -117,7 +118,7 @@
                       in (fs', list_comb (f, ts')) end
                   | Some (i', y) =>
                       let
-                        val (fs', ts') = foldl_map (subst subs) (fs, ls @ rs);
+                        val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
                         val fs'' = process_fun sign descr rec_eqns ((i', fname'), fs')
                       in (fs'', list_comb (y, ts'))
                       end)
@@ -138,13 +139,16 @@
               (fnames', fnss', (Const ("arbitrary", dummyT))::fns))
         | Some (ls, cargs', rs, rhs, eq) =>
             let
+              fun rec_index (DtRec k) = k
+                | rec_index (DtType ("fun", [_, DtRec k])) = k;
+
               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
               val rargs = map fst recs;
               val subs = map (rpair dummyT o fst) 
 		             (rev (rename_wrt_term rhs rargs));
               val ((fnames'', fnss''), rhs') = 
 		  (subst (map (fn ((x, y), z) =>
-			       (Free x, (dest_DtRec y, Free z)))
+			       (Free x, (rec_index y, Free z)))
 			  (recs ~~ subs))
 		   ((fnames', fnss'), rhs))
                   handle RecError s => primrec_eq_err sign s eq
@@ -240,7 +244,7 @@
       (if eq_set (names1, names2) then Theory.add_defs_i defs'
        else primrec_err ("functions " ^ commas_quote names2 ^
          "\nare not mutually recursive"));
-    val rewrites = (map mk_meta_eq rec_rewrites) @ (map (get_axiom thy' o fst) defs');
+    val rewrites = o_def :: (map mk_meta_eq rec_rewrites) @ (map (get_axiom thy' o fst) defs');
     val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names1 ^ " ...");
     val char_thms = map (fn (_, t) => prove_goalw_cterm rewrites (cterm_of (Theory.sign_of thy') t)
         (fn _ => [rtac refl 1])) eqns;
@@ -253,10 +257,17 @@
   in (thy'', char_thms) end;
 
 
-fun read_eqn thy ((name, s), srcs) =
-  ((name, readtm (Theory.sign_of thy) propT s), map (Attrib.global_attribute thy) srcs);
-
-fun add_primrec alt_name eqns thy = add_primrec_i alt_name (map (read_eqn thy) eqns) thy;
+fun add_primrec alt_name eqns thy =
+  let
+    val ((names, strings), srcss) = apfst split_list (split_list eqns);
+    val atts = map (map (Attrib.global_attribute thy)) srcss;
+    val eqn_ts = map (readtm (Theory.sign_of thy) propT) strings;
+    val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq)))
+      handle TERM _ => raise RecError "not a proper equation") eqn_ts;
+    val (_, eqn_ts') = InductivePackage.unify_consts (sign_of thy) rec_ts eqn_ts
+  in
+    add_primrec_i alt_name (names ~~ eqn_ts' ~~ atts) thy
+  end;
 
 
 (* outer syntax *)