src/HOL/Tools/function_package/mutual.ML
changeset 20654 d80502f0d701
parent 20534 b147d0c13f6e
child 20797 c1f0bc7e7d80
--- a/src/HOL/Tools/function_package/mutual.ML	Thu Sep 21 03:17:51 2006 +0200
+++ b/src/HOL/Tools/function_package/mutual.ML	Thu Sep 21 12:22:05 2006 +0200
@@ -12,6 +12,7 @@
   
   val prepare_fundef_mutual : ((string * typ) * mixfix) list 
                               -> term list 
+                              -> string (* default, unparsed term *)
                               -> local_theory 
                               -> ((FundefCommon.mutual_info * string * FundefCommon.prep_result) * local_theory)
 
@@ -42,7 +43,7 @@
 
 fun check_head fs t =
     if (case t of 
-          (Free (n, _)) => n mem fs
+          (Free (n, _)) => n mem (map fst fs)
         | _ => false)
     then dest_Free t
     else raise ERROR "Head symbol of every left hand side must be the new function." (* FIXME: Output the equation here *)
@@ -54,7 +55,7 @@
 
 
 (* Builds a curried clause description in abstracted form *)
-fun split_def fnames geq =
+fun split_def fs geq arities =
     let
       val (qs, imp) = open_all_all geq
 
@@ -63,7 +64,7 @@
 
       val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
       val (fc, args) = strip_comb f_args
-      val f as (fname, _) = check_head fnames fc
+      val f as (fname, _) = check_head fs fc
 
       fun add_bvs t is = add_loose_bnos (t, 0, is)
       val rhs_only = (add_bvs rhs [] \\ fold add_bvs args [])
@@ -71,8 +72,16 @@
                         |> map (fst o nth (rev qs))
       val _ = if null rhs_only then () 
 	      else raise ERROR "Variables occur on right hand side only." (* FIXME: Output vars *)
+
+      val k = length args
+
+      val arities' = case Symtab.lookup arities fname of
+                   NONE => Symtab.update (fname, k) arities
+                 | SOME i => if (i <> k) 
+                             then raise ERROR ("Function " ^ fname ^ " has different numbers of arguments in different equations")
+                             else arities
     in
-	((f, length args), (fname, qs, gs, args, rhs))
+	((fname, qs, gs, args, rhs), arities')
     end
 
 fun get_part fname =
@@ -89,45 +98,38 @@
     end;
 
 
-fun analyze_eqs ctxt fnames eqs =
+fun analyze_eqs ctxt fs eqs =
     let
-      (* FIXME: Add check for number of arguments
-	fun all_equal ((x as ((n:string,T), k:int))::xs) = if forall (fn ((n',_),k') => n = n' andalso k = k') xs then x
-							   else raise ERROR ("All equations in a block must describe the same "
-									     ^ "function and have the same number of arguments.")
-       *)
-								      
-        val (consts, fqgars) = split_list (map (split_def fnames) eqs)
+        val fnames = map fst fs 
+        val (fqgars, arities) = fold_map (split_def fs) eqs Symtab.empty
 
-        val different_consts = distinct (eq_fst (eq_fst eq_str)) consts
-	val cnames = map (fst o fst) different_consts
-
-	val check_rcs = exists_subterm (fn Free (n, _) => if n mem cnames 
+	val check_rcs = exists_subterm (fn Free (n, _) => if n mem fnames
 						          then raise ERROR "Recursive Calls not allowed in premises." else false
                                          | _ => false)
                         
 	val _ = forall (fn (_, _, gs, _, _) => forall check_rcs gs) fqgars
 
-	fun curried_types ((_,T), k) =
+	fun curried_types (fname, fT) =
 	    let
-		val (caTs, uaTs) = chop k (binder_types T)
+              val k = the_default 1 (Symtab.lookup arities fname)
+	      val (caTs, uaTs) = chop k (binder_types fT)
 	    in 
-		(caTs, uaTs ---> body_type T)
+		(caTs, uaTs ---> body_type fT)
 	    end
 
-	val (caTss, resultTs) = split_list (map curried_types different_consts)
+	val (caTss, resultTs) = split_list (map curried_types fs)
 	val argTs = map (foldr1 HOLogic.mk_prodT) caTss
 
 	val (RST,streeR, pthsR) = SumTools.mk_tree_distinct resultTs
 	val (ST, streeA, pthsA) = SumTools.mk_tree argTs
 
-	val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name cnames)
+	val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name fnames)
 	val fsum_type = ST --> RST
 
         val ([fsum_var_name], _) = Variable.add_fixes [ def_name ^ "_sum" ] ctxt
         val fsum_var = (fsum_var_name, fsum_type)
 
-	fun define (fvar as (n, T), _) caTs pthA pthR = 
+	fun define (fvar as (n, T)) caTs pthA pthR = 
 	    let 
 		val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs (* FIXME: Bind xs properly *)
 
@@ -139,7 +141,7 @@
 		(MutualPart {fvar=fvar,cargTs=caTs,pthA=pthA,pthR=pthR,f_def=def,f=NONE,f_defthm=NONE}, rew)
 	    end
 
-	val (parts, rews) = split_list (map4 define different_consts caTss pthsA pthsR)
+	val (parts, rews) = split_list (map4 define fs caTss pthsA pthsR)
 
         fun convert_eqs (f, qs, gs, args, rhs) =
             let
@@ -181,17 +183,13 @@
     end
 
 
-
-  
-
-
-fun prepare_fundef_mutual fixes eqss lthy =
+fun prepare_fundef_mutual fixes eqss default lthy =
     let 
-	val mutual = analyze_eqs lthy (map (fst o fst) fixes) eqss
+	val mutual = analyze_eqs lthy (map fst fixes) eqss
 	val Mutual {defname, fsum_var=(n, T), qglrs, ...} = mutual
 
         val (prep_result, fsum, lthy') =
-            FundefPrep.prepare_fundef defname (n, T, NoSyn) qglrs lthy
+            FundefPrep.prepare_fundef defname (n, T, NoSyn) qglrs default lthy
 
         val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
     in