--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Wed Sep 25 09:35:37 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Wed Sep 25 10:17:18 2013 +0200
@@ -60,7 +60,8 @@
val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
val massage_corec_code_rhs: Proof.context -> (term -> term list -> term) -> typ list -> typ ->
term -> term
- val fold_rev_corec_code_rhs: (term -> term list -> 'a -> 'a) -> term -> 'a -> 'a
+ val fold_rev_corec_code_rhs: Proof.context -> (term -> term list -> 'a -> 'a) -> typ list ->
+ term -> 'a -> 'a
val simplify_bool_ifs: theory -> term -> term list
val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
((term * term list list) list) list -> local_theory ->
@@ -221,18 +222,6 @@
massage_rec
end;
-fun fold_rev_let_if f =
- let
- fun fld t =
- (case Term.strip_comb t of
- (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
- | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches
- | (Const (@{const_name nat_case}, _), args) => fold_rev fld (fst (split_last args))
- | _ => f t)
- in
- fld
- end;
-
val massage_direct_corec_call = massage_let_if;
fun massage_indirect_corec_call ctxt has_call raw_massage_call bound_Ts U t =
@@ -319,7 +308,29 @@
fun massage_corec_code_rhs ctxt massage_ctr =
massage_let_if ctxt (K false) (uncurry massage_ctr o Term.strip_comb);
-fun fold_rev_corec_code_rhs f = fold_rev_let_if (uncurry f o Term.strip_comb);
+(* TODO: also support old-style datatypes.
+ (Ideally, we would have a proper registry for these things.) *)
+fun case_of ctxt =
+ fp_sugar_of ctxt #> Option.map (fst o dest_Const o #casex o of_fp_sugar #ctr_sugars);
+
+fun fold_rev_let_if ctxt f bound_Ts =
+ let
+ fun fld t =
+ (case Term.strip_comb t of
+ (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
+ | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches
+ | (Const (c, _), args as _ :: _) =>
+ let val (branches, obj) = split_last args in
+ (case fastype_of1 (bound_Ts, obj) of
+ Type (T_name, _) => if case_of ctxt T_name = SOME c then fold_rev fld branches else f t
+ | _ => f t)
+ end
+ | _ => f t)
+ in
+ fld
+ end;
+
+fun fold_rev_corec_code_rhs ctxt f = fold_rev_let_if ctxt (uncurry f o Term.strip_comb);
fun add_conjuncts (Const (@{const_name conj}, _) $ t $ t') = add_conjuncts t o add_conjuncts t'
| add_conjuncts t = cons t;