src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
changeset 53865 cadccda5be03
parent 53864 a48d4bd3faaa
child 53866 7c23df53af01
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 09:35:37 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML	Wed Sep 25 10:17:18 2013 +0200
@@ -60,7 +60,8 @@
   val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term
   val massage_corec_code_rhs: Proof.context -> (term -> term list -> term) -> typ list -> typ ->
     term -> term
-  val fold_rev_corec_code_rhs: (term -> term list -> 'a -> 'a) -> term -> 'a -> 'a
+  val fold_rev_corec_code_rhs: Proof.context -> (term -> term list -> 'a -> 'a) -> typ list ->
+    term -> 'a -> 'a
   val simplify_bool_ifs: theory -> term -> term list
   val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) ->
     ((term * term list list) list) list -> local_theory ->
@@ -221,18 +222,6 @@
     massage_rec
   end;
 
-fun fold_rev_let_if f =
-  let
-    fun fld t =
-      (case Term.strip_comb t of
-        (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
-      | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches
-      | (Const (@{const_name nat_case}, _), args) => fold_rev fld (fst (split_last args))
-      | _ => f t)
-  in
-    fld
-  end;
-
 val massage_direct_corec_call = massage_let_if;
 
 fun massage_indirect_corec_call ctxt has_call raw_massage_call bound_Ts U t =
@@ -319,7 +308,29 @@
 fun massage_corec_code_rhs ctxt massage_ctr =
   massage_let_if ctxt (K false) (uncurry massage_ctr o Term.strip_comb);
 
-fun fold_rev_corec_code_rhs f = fold_rev_let_if (uncurry f o Term.strip_comb);
+(* TODO: also support old-style datatypes.
+   (Ideally, we would have a proper registry for these things.) *)
+fun case_of ctxt =
+  fp_sugar_of ctxt #> Option.map (fst o dest_Const o #casex o of_fp_sugar #ctr_sugars);
+
+fun fold_rev_let_if ctxt f bound_Ts =
+  let
+    fun fld t =
+      (case Term.strip_comb t of
+        (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
+      | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches
+      | (Const (c, _), args as _ :: _) =>
+        let val (branches, obj) = split_last args in
+          (case fastype_of1 (bound_Ts, obj) of
+            Type (T_name, _) => if case_of ctxt T_name = SOME c then fold_rev fld branches else f t
+          | _ => f t)
+        end
+      | _ => f t)
+  in
+    fld
+  end;
+
+fun fold_rev_corec_code_rhs ctxt f = fold_rev_let_if ctxt (uncurry f o Term.strip_comb);
 
 fun add_conjuncts (Const (@{const_name conj}, _) $ t $ t') = add_conjuncts t o add_conjuncts t'
   | add_conjuncts t = cons t;