Some changes to allow mutually recursive, overloaded functions with same name.
authorberghofe
Fri, 08 Jul 2005 11:57:15 +0200
changeset 16765 b8b1f310877f
parent 16764 ca81a99c5bc1
child 16766 ea667a5426fe
Some changes to allow mutually recursive, overloaded functions with same name.
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Fri Jul 08 11:39:59 2005 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Fri Jul 08 11:57:15 2005 +0200
@@ -43,7 +43,7 @@
 	else raise RecError "illegal schematic variable(s)";
 
     val (recfun, args) = strip_comb lhs;
-    val (fname, _) = dest_Const recfun handle TERM _ => 
+    val fnameT = dest_Const recfun handle TERM _ => 
       raise RecError "function is not declared as constant in theory";
 
     val (ls', rest)  = take_prefix is_Free args;
@@ -72,9 +72,9 @@
      (check_vars "repeated variable names in pattern: " (duplicates lfrees);
       check_vars "extra variables on rhs: "
         (map dest_Free (term_frees rhs) \\ lfrees);
-      case assoc (rec_fns, fname) of
+      case assoc (rec_fns, fnameT) of
         NONE =>
-          (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
+          (fnameT, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
       | SOME (_, rpos', eqns) =>
           if isSome (assoc (eqns, cname)) then
             raise RecError "constructor already occurred as pattern"
@@ -82,13 +82,13 @@
             raise RecError "position of recursive argument inconsistent"
           else
             overwrite (rec_fns, 
-		       (fname, 
+		       (fnameT, 
 			(tname, rpos,
 			 (cname, (ls, cargs, rs, rhs, eq))::eqns))))
   end
   handle RecError s => primrec_eq_err sign s eq;
 
-fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) =
+fun process_fun sign descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
   let
     val (_, (tname, _, constrs)) = List.nth (descr, i);
 
@@ -101,10 +101,10 @@
       | subst subs (fs, t as (_ $ _)) =
           let val (f, ts) = strip_comb t;
           in
-            if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then
+            if is_Const f andalso dest_Const f mem map fst rec_eqns then
               let
-                val (fname', _) = dest_Const f;
-                val (_, rpos, _) = valOf (assoc (rec_eqns, fname'));
+                val fnameT' as (fname', _) = dest_Const f;
+                val (_, rpos, _) = valOf (assoc (rec_eqns, fnameT'));
                 val ls = Library.take (rpos, ts);
                 val rest = Library.drop (rpos, ts);
                 val (x', rs) = (hd rest, tl rest)
@@ -120,7 +120,7 @@
                   | SOME (i', y) =>
                       let
                         val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
-                        val fs'' = process_fun sign descr rec_eqns ((i', fname'), fs')
+                        val fs'' = process_fun sign descr rec_eqns ((i', fnameT'), fs')
                       in (fs'', list_comb (y, ts'))
                       end)
               end
@@ -133,41 +133,41 @@
 
     (* translate rec equations into function arguments suitable for rec comb *)
 
-    fun trans eqns ((cname, cargs), (fnames', fnss', fns)) =
+    fun trans eqns ((cname, cargs), (fnameTs', fnss', fns)) =
       (case assoc (eqns, cname) of
           NONE => (warning ("No equation for constructor " ^ quote cname ^
             "\nin definition of function " ^ quote fname);
-              (fnames', fnss', (Const ("arbitrary", dummyT))::fns))
+              (fnameTs', fnss', (Const ("arbitrary", dummyT))::fns))
         | SOME (ls, cargs', rs, rhs, eq) =>
             let
               val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
               val rargs = map fst recs;
               val subs = map (rpair dummyT o fst) 
 		             (rev (rename_wrt_term rhs rargs));
-              val ((fnames'', fnss''), rhs') = 
+              val ((fnameTs'', fnss''), rhs') = 
 		  (subst (map (fn ((x, y), z) =>
 			       (Free x, (body_index y, Free z)))
 			  (recs ~~ subs))
-		   ((fnames', fnss'), rhs))
+		   ((fnameTs', fnss'), rhs))
                   handle RecError s => primrec_eq_err sign s eq
-            in (fnames'', fnss'', 
+            in (fnameTs'', fnss'', 
 		(list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
             end)
 
-  in (case assoc (fnames, i) of
+  in (case assoc (fnameTs, i) of
       NONE =>
-        if exists (equal fname o snd) fnames then
+        if exists (equal fnameT o snd) fnameTs then
           raise RecError ("inconsistent functions for datatype " ^ quote tname)
         else
           let
-            val (_, _, eqns) = valOf (assoc (rec_eqns, fname));
-            val (fnames', fnss', fns) = foldr (trans eqns)
-              ((i, fname)::fnames, fnss, []) constrs
+            val (_, _, eqns) = valOf (assoc (rec_eqns, fnameT));
+            val (fnameTs', fnss', fns) = foldr (trans eqns)
+              ((i, fnameT)::fnameTs, fnss, []) constrs
           in
-            (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
+            (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
           end
-    | SOME fname' =>
-        if fname = fname' then (fnames, fnss)
+    | SOME fnameT' =>
+        if fnameT = fnameT' then (fnameTs, fnss)
         else raise RecError ("inconsistent functions for datatype " ^ quote tname))
   end;
 
@@ -241,20 +241,21 @@
 	    primrec_err ("datatypes " ^ commas_quote tnames ^ 
 			 "\nare not mutually recursive")
 	else snd (hd dts);
-    val (fnames, fnss) = foldr (process_fun sg descr rec_eqns)
+    val (fnameTs, fnss) = foldr (process_fun sg descr rec_eqns)
 	                       ([], []) main_fns;
     val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
     val defs' = map (make_def sg fs) defs;
-    val names1 = map snd fnames;
-    val names2 = map fst rec_eqns;
+    val nameTs1 = map snd fnameTs;
+    val nameTs2 = map fst rec_eqns;
     val primrec_name =
       if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
     val (thy', defs_thms') = thy |> Theory.add_path primrec_name |>
-      (if eq_set (names1, names2) then (PureThy.add_defs_i false o map Thm.no_attributes) defs'
-       else primrec_err ("functions " ^ commas_quote names2 ^
+      (if eq_set (nameTs1, nameTs2) then (PureThy.add_defs_i false o map Thm.no_attributes) defs'
+       else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^
          "\nare not mutually recursive"));
     val rewrites = (map mk_meta_eq rec_rewrites) @ defs_thms';
-    val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names1 ^ " ...");
+    val _ = message ("Proving equations for primrec function(s) " ^
+      commas_quote (map fst nameTs1) ^ " ...");
     val simps = map (fn (_, t) => prove_goalw_cterm rewrites (cterm_of (Theory.sign_of thy') t)
         (fn _ => [rtac refl 1])) eqns;
     val (thy'', simps') = PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts) thy';