# HG changeset patch # User wenzelm # Date 1160177466 -7200 # Node ID 384c5bb713b28a4f570c16c67464538e4ff51ba8 # Parent 368b997ad67e5702517dfc2ce30a55b48cb48955 mk_partial_rules_mutual: expand result terms/thms; diff -r 368b997ad67e -r 384c5bb713b2 src/HOL/Tools/function_package/mutual.ML --- a/src/HOL/Tools/function_package/mutual.ML Sat Oct 07 01:31:05 2006 +0200 +++ b/src/HOL/Tools/function_package/mutual.ML Sat Oct 07 01:31:06 2006 +0200 @@ -2,22 +2,21 @@ ID: $Id$ Author: Alexander Krauss, TU Muenchen -A package for general recursive function definitions. +A package for general recursive function definitions. Tools for mutual recursive definitions. - *) -signature FUNDEF_MUTUAL = +signature FUNDEF_MUTUAL = sig - - val prepare_fundef_mutual : ((string * typ) * mixfix) list - -> term list + + val prepare_fundef_mutual : ((string * typ) * mixfix) list + -> term list -> string (* default, unparsed term *) - -> local_theory + -> local_theory -> ((FundefCommon.mutual_info * string * FundefCommon.prep_result) * local_theory) - val mk_partial_rules_mutual : Proof.context -> FundefCommon.mutual_info -> FundefCommon.prep_result -> thm -> + val mk_partial_rules_mutual : Proof.context -> FundefCommon.mutual_info -> FundefCommon.prep_result -> thm -> FundefCommon.fundef_mresult val sort_by_function : FundefCommon.mutual_info -> string list -> 'a list -> 'a list list @@ -25,7 +24,7 @@ end -structure FundefMutual: FUNDEF_MUTUAL = +structure FundefMutual: FUNDEF_MUTUAL = struct open FundefCommon @@ -36,10 +35,10 @@ -fun mutual_induct_Pnames n = +fun mutual_induct_Pnames n = if n < 5 then fst (chop n ["P","Q","R","S"]) else map (fn i => "P" ^ string_of_int i) (1 upto n) - + fun open_all_all (Const ("all", _) $ Abs (n, T, b)) = apfst (cons (n, T)) (open_all_all b) | open_all_all t = ([], t) @@ -59,24 +58,24 @@ val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq) val (head, args) = strip_comb f_args - val invalid_head_msg = "Head symbol of left hand side must be " ^ plural "" "one out of " fnames ^ commas_quote fnames + val invalid_head_msg = "Head symbol of left hand side must be " ^ plural "" "one out of " fnames ^ commas_quote fnames val fname = fst (dest_Free head) handle TERM _ => input_error invalid_head_msg val _ = if fname mem fnames then () else input_error invalid_head_msg - + fun add_bvs t is = add_loose_bnos (t, 0, is) val rvs = (add_bvs rhs [] \\ fold add_bvs args []) |> map (fst o nth (rev qs)) - val _ = if null rvs then () - else input_error ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs + val _ = if null rvs then () + else input_error ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs ^ " occur" ^ plural "s" "" rvs ^ " on right hand side only:") val _ = (fold o fold_aterms) (fn Free (n, _) => if n mem fnames - then input_error "Recursive Calls not allowed in premises:" + then input_error "Recursive Calls not allowed in premises:" else I | _ => I) gs () @@ -84,11 +83,11 @@ val arities' = case Symtab.lookup arities fname of NONE => Symtab.update (fname, k) arities - | SOME i => if (i <> k) + | SOME i => if (i <> k) then input_error ("Function " ^ quote fname ^ " has different numbers of arguments in different equations") else arities in - ((fname, qs, gs, args, rhs), arities') + ((fname, qs, gs, args, rhs), arities') end fun get_part fname = @@ -96,7 +95,7 @@ (* FIXME *) fun mk_prod_abs e (t1, t2) = - let + let val bTs = rev (map snd e) val T1 = fastype_of1 (bTs, t1) val T2 = fastype_of1 (bTs, t2) @@ -107,55 +106,55 @@ fun analyze_eqs ctxt fs eqs = let - val fnames = map fst fs + val fnames = map fst fs val (fqgars, arities) = fold_map (split_def ctxt fnames) eqs Symtab.empty - fun curried_types (fname, fT) = - let + fun curried_types (fname, fT) = + let val k = the_default 1 (Symtab.lookup arities fname) - val (caTs, uaTs) = chop k (binder_types fT) - in - (caTs, uaTs ---> body_type fT) - end + val (caTs, uaTs) = chop k (binder_types fT) + in + (caTs, uaTs ---> body_type fT) + end - val (caTss, resultTs) = split_list (map curried_types fs) - val argTs = map (foldr1 HOLogic.mk_prodT) caTss + val (caTss, resultTs) = split_list (map curried_types fs) + val argTs = map (foldr1 HOLogic.mk_prodT) caTss - val (RST,streeR, pthsR) = SumTools.mk_tree_distinct resultTs - val (ST, streeA, pthsA) = SumTools.mk_tree argTs + val (RST,streeR, pthsR) = SumTools.mk_tree_distinct resultTs + val (ST, streeA, pthsA) = SumTools.mk_tree argTs - val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name fnames) - val fsum_type = ST --> RST + val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name fnames) + val fsum_type = ST --> RST val ([fsum_var_name], _) = Variable.add_fixes [ def_name ^ "_sum" ] ctxt val fsum_var = (fsum_var_name, fsum_type) - fun define (fvar as (n, T)) caTs pthA pthR = - let - val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs (* FIXME: Bind xs properly *) + fun define (fvar as (n, T)) caTs pthA pthR = + let + val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs (* FIXME: Bind xs properly *) + + val f_exp = SumTools.mk_proj streeR pthR (Free fsum_var $ SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod vars)) + val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp) - val f_exp = SumTools.mk_proj streeR pthR (Free fsum_var $ SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod vars)) - val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp) - - val rew = (n, fold_rev lambda vars f_exp) - in - (MutualPart {fvar=fvar,cargTs=caTs,pthA=pthA,pthR=pthR,f_def=def,f=NONE,f_defthm=NONE}, rew) - end + val rew = (n, fold_rev lambda vars f_exp) + in + (MutualPart {fvar=fvar,cargTs=caTs,pthA=pthA,pthR=pthR,f_def=def,f=NONE,f_defthm=NONE}, rew) + end - val (parts, rews) = split_list (map4 define fs caTss pthsA pthsR) + val (parts, rews) = split_list (map4 define fs caTss pthsA pthsR) fun convert_eqs (f, qs, gs, args, rhs) = let val MutualPart {pthA, pthR, ...} = get_part f parts in - (qs, gs, SumTools.mk_inj streeA pthA (foldr1 (mk_prod_abs qs) args), - SumTools.mk_inj streeR pthR (replace_frees rews rhs) + (qs, gs, SumTools.mk_inj streeA pthA (foldr1 (mk_prod_abs qs) args), + SumTools.mk_inj streeR pthR (replace_frees rews rhs) |> Envir.norm_term (Envir.empty 0)) end - val qglrs = map convert_eqs fqgars + val qglrs = map convert_eqs fqgars in - Mutual {defname=def_name,fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, + Mutual {defname=def_name,fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE} end @@ -166,28 +165,28 @@ let fun def ((MutualPart {fvar=(fname, fT), cargTs, pthA, pthR, f_def, ...}), (_, mixfix)) lthy = let - val ((f, (_, f_defthm)), lthy') = LocalTheory.def ((fname, mixfix), - ((fname ^ "_def", []), Term.subst_bound (fsum, f_def))) + val ((f, (_, f_defthm)), lthy') = LocalTheory.def ((fname, mixfix), + ((fname ^ "_def", []), Term.subst_bound (fsum, f_def))) lthy in - (MutualPart {fvar=(fname, fT), cargTs=cargTs, pthA=pthA, pthR=pthR, f_def=f_def, + (MutualPart {fvar=(fname, fT), cargTs=cargTs, pthA=pthA, pthR=pthR, f_def=f_def, f=SOME f, f_defthm=SOME f_defthm }, lthy') end val Mutual { defname, fsum_var, ST, RST, streeA, streeR, parts, fqgars, qglrs, ... } = mutual - val (parts', lthy') = fold_map def (parts ~~ fixes) lthy + val (parts', lthy') = fold_map def (parts ~~ fixes) lthy in - (Mutual { defname=defname, fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts', + (Mutual { defname=defname, fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts', fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum }, lthy') end fun prepare_fundef_mutual fixes eqss default lthy = - let - val mutual = analyze_eqs lthy (map fst fixes) eqss - val Mutual {defname, fsum_var=(n, T), qglrs, ...} = mutual + let + val mutual = analyze_eqs lthy (map fst fixes) eqss + val Mutual {defname, fsum_var=(n, T), qglrs, ...} = mutual val (prep_result, fsum, lthy') = FundefPrep.prepare_fundef defname (n, T, NoSyn) qglrs default lthy @@ -199,13 +198,13 @@ (* Beta-reduce both sides of a meta-equality *) -fun beta_norm_eq thm = +fun beta_norm_eq thm = let - val (lhs, rhs) = dest_equals (cprop_of thm) - val lhs_conv = beta_conversion false lhs - val rhs_conv = beta_conversion false rhs + val (lhs, rhs) = dest_equals (cprop_of thm) + val lhs_conv = beta_conversion false lhs + val rhs_conv = beta_conversion false rhs in - transitive (symmetric lhs_conv) (transitive thm rhs_conv) + transitive (symmetric lhs_conv) (transitive thm rhs_conv) end fun beta_reduce thm = Thm.equal_elim (Thm.beta_conversion true (cprop_of thm)) thm @@ -226,7 +225,7 @@ val cqs = map (cterm_of thy) qs val ags = map (assume o cterm_of thy) gs - + val import = fold forall_elim cqs #> fold implies_elim_swp ags @@ -254,18 +253,18 @@ in reflexive (cterm_of thy (lambda x (SumTools.mk_proj streeR pthR x))) (* PR(x) == PR(x) *) |> (fn it => combination it (simp RS eq_reflection)) - |> beta_norm_eq (* PR(S(I(as))) == PR(IR(...)) *) + |> beta_norm_eq (* PR(S(I(as))) == PR(IR(...)) *) |> transitive f_def_inst (* f ... == PR(IR(...)) *) |> simplify (HOL_basic_ss addsimps [SumTools.projl_inl, SumTools.projr_inr]) (* f ... == ... *) |> simplify (HOL_basic_ss addsimps all_f_defs) (* f ... == ... *) |> (fn it => it RS meta_eq_to_obj_eq) |> restore_cond |> export - end + end (* FIXME HACK *) -fun mk_applied_form ctxt caTs thm = +fun mk_applied_form ctxt caTs thm = let val thy = ProofContext.theory_of ctxt val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *) @@ -276,73 +275,77 @@ |> forall_elim_vars 0 end - + fun mutual_induct_rules thy induct all_f_defs (Mutual {RST, parts, streeA, ...}) = let - fun mk_P (MutualPart {cargTs, ...}) Pname = - let - val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs - val atup = foldr1 HOLogic.mk_prod avars - in - tupled_lambda atup (list_comb (Free (Pname, cargTs ---> HOLogic.boolT), avars)) - end - - val Ps = map2 mk_P parts (mutual_induct_Pnames (length parts)) - val case_exp = SumTools.mk_sumcases streeA HOLogic.boolT Ps - - val induct_inst = - forall_elim (cterm_of thy case_exp) induct - |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules)) - |> full_simplify (HOL_basic_ss addsimps all_f_defs) + fun mk_P (MutualPart {cargTs, ...}) Pname = + let + val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs + val atup = foldr1 HOLogic.mk_prod avars + in + tupled_lambda atup (list_comb (Free (Pname, cargTs ---> HOLogic.boolT), avars)) + end + + val Ps = map2 mk_P parts (mutual_induct_Pnames (length parts)) + val case_exp = SumTools.mk_sumcases streeA HOLogic.boolT Ps - fun mk_proj rule (MutualPart {cargTs, pthA, ...}) = - let - val afs = map_index (fn (i,T) => Free ("a" ^ string_of_int i, T)) cargTs - val inj = SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod afs) - in - rule - |> forall_elim (cterm_of thy inj) - |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules)) - end + val induct_inst = + forall_elim (cterm_of thy case_exp) induct + |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules)) + |> full_simplify (HOL_basic_ss addsimps all_f_defs) + + fun mk_proj rule (MutualPart {cargTs, pthA, ...}) = + let + val afs = map_index (fn (i,T) => Free ("a" ^ string_of_int i, T)) cargTs + val inj = SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod afs) + in + rule + |> forall_elim (cterm_of thy inj) + |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules)) + end in map (mk_proj induct_inst) parts end - + fun mk_partial_rules_mutual lthy (m as Mutual {RST, parts, streeR, fqgars, ...}) data prep_result = let val thy = ProofContext.theory_of lthy - + + (* FIXME !? *) + val expand = Assumption.export false lthy (LocalTheory.target_of lthy); + val expand_term = Drule.term_rule thy expand; + val result = FundefProof.mk_partial_rules thy data prep_result val FundefResult {f, G, R, completeness, psimps, subset_pinduct,simple_pinduct,total_intro,dom_intros} = result - - val all_f_defs = map (fn MutualPart {f_defthm = SOME f_def, cargTs, ...} => - mk_applied_form lthy cargTs (symmetric (Thm.freezeT f_def))) + + val all_f_defs = map (fn MutualPart {f_defthm = SOME f_def, cargTs, ...} => + mk_applied_form lthy cargTs (symmetric (Thm.freezeT f_def))) parts - |> print - - fun mk_mpsimp fqgar sum_psimp = + + fun mk_mpsimp fqgar sum_psimp = in_context lthy fqgar (recover_mutual_psimp thy RST streeR all_f_defs parts) sum_psimp - + val mpsimps = map2 mk_mpsimp fqgars psimps - + val minducts = mutual_induct_rules thy simple_pinduct all_f_defs m val termination = full_simplify (HOL_basic_ss addsimps all_f_defs) total_intro in - FundefMResult { f=f, G=G, R=R, - psimps=mpsimps, subset_pinducts=[subset_pinduct], simple_pinducts=minducts, - cases=completeness, termination=termination, domintros=dom_intros } + FundefMResult { f=expand_term f, G=expand_term G, R=expand_term R, + psimps=map expand mpsimps, subset_pinducts=[expand subset_pinduct], simple_pinducts=map expand minducts, + cases=expand completeness, termination=expand termination, + domintros=map expand dom_intros } end -(* puts an object in the "right bucket" *) +(* puts an object in the "right bucket" *) fun store_grouped P x [] = [] - | store_grouped P x ((l, xs)::bs) = + | store_grouped P x ((l, xs)::bs) = if P (x, l) then ((l, x::xs)::bs) else ((l, xs)::store_grouped P x bs) fun sort_by_function (Mutual {fqgars, ...}) names xs = @@ -352,32 +355,4 @@ |> map (snd #> map snd) (* and remove the labels afterwards *) - - - end - - - - - - - - - - - - - - - - - - - - - - - - -