src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML
changeset 55341 3d2c97392e25
parent 55339 f09037306f25
child 55342 1bd9e637ac9f
--- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Wed Feb 05 18:19:25 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Wed Feb 05 23:30:02 2014 +0100
@@ -143,7 +143,7 @@
       |> (fn [cs] => cs | css => [s_disjs (map s_conjs css)])
   end;
 
-fun fold_rev_let_if_case ctxt f bound_Ts t =
+fun fold_rev_let_if_case ctxt f bound_Ts =
   let
     val thy = Proof_Context.theory_of ctxt;
 
@@ -158,17 +158,16 @@
             (case fastype_of1 (bound_Ts, nth args n) of
               Type (s, Ts) =>
               (case dest_case ctxt s Ts t of
-                SOME (ctr_sugar as {sel_splits = _ :: _, ...}, conds', branches) =>
-                apfst (cons ctr_sugar) o fold_rev (uncurry fld)
-                  (map (append conds o conjuncts_s) conds' ~~ branches)
-              | _ => apsnd (f conds t))
-            | _ => apsnd (f conds t))
+                SOME ({sel_splits = _ :: _, ...}, conds', branches) =>
+                fold_rev (uncurry fld) (map (append conds o conjuncts_s) conds' ~~ branches)
+              | _ => f conds t)
+            | _ => f conds t)
           else
-            apsnd (f conds t)
+            f conds t
         end
-      | _ => apsnd (f conds t))
+      | _ => f conds t)
   in
-    fld [] t o pair []
+    fld []
   end;
 
 fun case_of ctxt s =
@@ -336,10 +335,23 @@
     (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
 
 fun fold_rev_corec_code_rhs ctxt f =
-  snd ooo fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);
+  fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);
 
 fun case_thms_of_term ctxt bound_Ts t =
-  let val (ctr_sugars, _) = fold_rev_let_if_case ctxt (K (K I)) bound_Ts t () in
+  let
+    fun ctr_sugar_of_case c s =
+      (case ctr_sugar_of ctxt s of
+        SOME (ctr_sugar as {casex = Const (c', _), ...}) => if c' = c then SOME ctr_sugar else NONE
+      | _ => NONE);
+    fun add_ctr_sugar (s, Type (@{type_name fun}, [_, T])) =
+        binder_types T
+        |> map_filter (try (fst o dest_Type))
+        |> distinct (op =)
+        |> map_filter (ctr_sugar_of_case s)
+      | add_ctr_sugar _ = [];
+
+    val ctr_sugars = maps add_ctr_sugar (Term.add_consts t []);
+  in
     (maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #disc_exhausts ctr_sugars,
      maps #sel_splits ctr_sugars, maps #sel_split_asms ctr_sugars)
   end;
@@ -884,10 +896,9 @@
   let
     val sel_no = find_first (curry (op =) ctr o #ctr) basic_ctr_specs
       |> find_index (curry (op =) sel) o #sels o the;
-    fun find t = if has_call t then snd (fold_rev_let_if_case ctxt (K cons) [] t []) else [];
   in
-    find rhs_term
-    |> K |> nth_map sel_no |> AList.map_entry (op =) ctr
+    K (if has_call rhs_term then fold_rev_let_if_case ctxt (K cons) [] rhs_term [] else [])
+    |> nth_map sel_no |> AList.map_entry (op =) ctr
   end;
 
 fun applied_fun_of fun_name fun_T fun_args =