--- a/src/HOL/Tools/function_package/mutual.ML Thu Sep 21 03:17:51 2006 +0200
+++ b/src/HOL/Tools/function_package/mutual.ML Thu Sep 21 12:22:05 2006 +0200
@@ -12,6 +12,7 @@
val prepare_fundef_mutual : ((string * typ) * mixfix) list
-> term list
+ -> string (* default, unparsed term *)
-> local_theory
-> ((FundefCommon.mutual_info * string * FundefCommon.prep_result) * local_theory)
@@ -42,7 +43,7 @@
fun check_head fs t =
if (case t of
- (Free (n, _)) => n mem fs
+ (Free (n, _)) => n mem (map fst fs)
| _ => false)
then dest_Free t
else raise ERROR "Head symbol of every left hand side must be the new function." (* FIXME: Output the equation here *)
@@ -54,7 +55,7 @@
(* Builds a curried clause description in abstracted form *)
-fun split_def fnames geq =
+fun split_def fs geq arities =
let
val (qs, imp) = open_all_all geq
@@ -63,7 +64,7 @@
val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
val (fc, args) = strip_comb f_args
- val f as (fname, _) = check_head fnames fc
+ val f as (fname, _) = check_head fs fc
fun add_bvs t is = add_loose_bnos (t, 0, is)
val rhs_only = (add_bvs rhs [] \\ fold add_bvs args [])
@@ -71,8 +72,16 @@
|> map (fst o nth (rev qs))
val _ = if null rhs_only then ()
else raise ERROR "Variables occur on right hand side only." (* FIXME: Output vars *)
+
+ val k = length args
+
+ val arities' = case Symtab.lookup arities fname of
+ NONE => Symtab.update (fname, k) arities
+ | SOME i => if (i <> k)
+ then raise ERROR ("Function " ^ fname ^ " has different numbers of arguments in different equations")
+ else arities
in
- ((f, length args), (fname, qs, gs, args, rhs))
+ ((fname, qs, gs, args, rhs), arities')
end
fun get_part fname =
@@ -89,45 +98,38 @@
end;
-fun analyze_eqs ctxt fnames eqs =
+fun analyze_eqs ctxt fs eqs =
let
- (* FIXME: Add check for number of arguments
- fun all_equal ((x as ((n:string,T), k:int))::xs) = if forall (fn ((n',_),k') => n = n' andalso k = k') xs then x
- else raise ERROR ("All equations in a block must describe the same "
- ^ "function and have the same number of arguments.")
- *)
-
- val (consts, fqgars) = split_list (map (split_def fnames) eqs)
+ val fnames = map fst fs
+ val (fqgars, arities) = fold_map (split_def fs) eqs Symtab.empty
- val different_consts = distinct (eq_fst (eq_fst eq_str)) consts
- val cnames = map (fst o fst) different_consts
-
- val check_rcs = exists_subterm (fn Free (n, _) => if n mem cnames
+ val check_rcs = exists_subterm (fn Free (n, _) => if n mem fnames
then raise ERROR "Recursive Calls not allowed in premises." else false
| _ => false)
val _ = forall (fn (_, _, gs, _, _) => forall check_rcs gs) fqgars
- fun curried_types ((_,T), k) =
+ fun curried_types (fname, fT) =
let
- val (caTs, uaTs) = chop k (binder_types T)
+ val k = the_default 1 (Symtab.lookup arities fname)
+ val (caTs, uaTs) = chop k (binder_types fT)
in
- (caTs, uaTs ---> body_type T)
+ (caTs, uaTs ---> body_type fT)
end
- val (caTss, resultTs) = split_list (map curried_types different_consts)
+ val (caTss, resultTs) = split_list (map curried_types fs)
val argTs = map (foldr1 HOLogic.mk_prodT) caTss
val (RST,streeR, pthsR) = SumTools.mk_tree_distinct resultTs
val (ST, streeA, pthsA) = SumTools.mk_tree argTs
- val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name cnames)
+ val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name fnames)
val fsum_type = ST --> RST
val ([fsum_var_name], _) = Variable.add_fixes [ def_name ^ "_sum" ] ctxt
val fsum_var = (fsum_var_name, fsum_type)
- fun define (fvar as (n, T), _) caTs pthA pthR =
+ fun define (fvar as (n, T)) caTs pthA pthR =
let
val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs (* FIXME: Bind xs properly *)
@@ -139,7 +141,7 @@
(MutualPart {fvar=fvar,cargTs=caTs,pthA=pthA,pthR=pthR,f_def=def,f=NONE,f_defthm=NONE}, rew)
end
- val (parts, rews) = split_list (map4 define different_consts caTss pthsA pthsR)
+ val (parts, rews) = split_list (map4 define fs caTss pthsA pthsR)
fun convert_eqs (f, qs, gs, args, rhs) =
let
@@ -181,17 +183,13 @@
end
-
-
-
-
-fun prepare_fundef_mutual fixes eqss lthy =
+fun prepare_fundef_mutual fixes eqss default lthy =
let
- val mutual = analyze_eqs lthy (map (fst o fst) fixes) eqss
+ val mutual = analyze_eqs lthy (map fst fixes) eqss
val Mutual {defname, fsum_var=(n, T), qglrs, ...} = mutual
val (prep_result, fsum, lthy') =
- FundefPrep.prepare_fundef defname (n, T, NoSyn) qglrs lthy
+ FundefPrep.prepare_fundef defname (n, T, NoSyn) qglrs default lthy
val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
in