diff -r a48d4bd3faaa -r cadccda5be03 src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Wed Sep 25 09:35:37 2013 +0200 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Wed Sep 25 10:17:18 2013 +0200 @@ -60,7 +60,8 @@ val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term val massage_corec_code_rhs: Proof.context -> (term -> term list -> term) -> typ list -> typ -> term -> term - val fold_rev_corec_code_rhs: (term -> term list -> 'a -> 'a) -> term -> 'a -> 'a + val fold_rev_corec_code_rhs: Proof.context -> (term -> term list -> 'a -> 'a) -> typ list -> + term -> 'a -> 'a val simplify_bool_ifs: theory -> term -> term list val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) -> ((term * term list list) list) list -> local_theory -> @@ -221,18 +222,6 @@ massage_rec end; -fun fold_rev_let_if f = - let - fun fld t = - (case Term.strip_comb t of - (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1)) - | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches - | (Const (@{const_name nat_case}, _), args) => fold_rev fld (fst (split_last args)) - | _ => f t) - in - fld - end; - val massage_direct_corec_call = massage_let_if; fun massage_indirect_corec_call ctxt has_call raw_massage_call bound_Ts U t = @@ -319,7 +308,29 @@ fun massage_corec_code_rhs ctxt massage_ctr = massage_let_if ctxt (K false) (uncurry massage_ctr o Term.strip_comb); -fun fold_rev_corec_code_rhs f = fold_rev_let_if (uncurry f o Term.strip_comb); +(* TODO: also support old-style datatypes. + (Ideally, we would have a proper registry for these things.) *) +fun case_of ctxt = + fp_sugar_of ctxt #> Option.map (fst o dest_Const o #casex o of_fp_sugar #ctr_sugars); + +fun fold_rev_let_if ctxt f bound_Ts = + let + fun fld t = + (case Term.strip_comb t of + (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1)) + | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches + | (Const (c, _), args as _ :: _) => + let val (branches, obj) = split_last args in + (case fastype_of1 (bound_Ts, obj) of + Type (T_name, _) => if case_of ctxt T_name = SOME c then fold_rev fld branches else f t + | _ => f t) + end + | _ => f t) + in + fld + end; + +fun fold_rev_corec_code_rhs ctxt f = fold_rev_let_if ctxt (uncurry f o Term.strip_comb); fun add_conjuncts (Const (@{const_name conj}, _) $ t $ t') = add_conjuncts t o add_conjuncts t' | add_conjuncts t = cons t;