--- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML Wed Feb 05 18:19:25 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML Wed Feb 05 23:30:02 2014 +0100
@@ -143,7 +143,7 @@
|> (fn [cs] => cs | css => [s_disjs (map s_conjs css)])
end;
-fun fold_rev_let_if_case ctxt f bound_Ts t =
+fun fold_rev_let_if_case ctxt f bound_Ts =
let
val thy = Proof_Context.theory_of ctxt;
@@ -158,17 +158,16 @@
(case fastype_of1 (bound_Ts, nth args n) of
Type (s, Ts) =>
(case dest_case ctxt s Ts t of
- SOME (ctr_sugar as {sel_splits = _ :: _, ...}, conds', branches) =>
- apfst (cons ctr_sugar) o fold_rev (uncurry fld)
- (map (append conds o conjuncts_s) conds' ~~ branches)
- | _ => apsnd (f conds t))
- | _ => apsnd (f conds t))
+ SOME ({sel_splits = _ :: _, ...}, conds', branches) =>
+ fold_rev (uncurry fld) (map (append conds o conjuncts_s) conds' ~~ branches)
+ | _ => f conds t)
+ | _ => f conds t)
else
- apsnd (f conds t)
+ f conds t
end
- | _ => apsnd (f conds t))
+ | _ => f conds t)
in
- fld [] t o pair []
+ fld []
end;
fun case_of ctxt s =
@@ -336,10 +335,23 @@
(fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb);
fun fold_rev_corec_code_rhs ctxt f =
- snd ooo fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);
+ fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb);
fun case_thms_of_term ctxt bound_Ts t =
- let val (ctr_sugars, _) = fold_rev_let_if_case ctxt (K (K I)) bound_Ts t () in
+ let
+ fun ctr_sugar_of_case c s =
+ (case ctr_sugar_of ctxt s of
+ SOME (ctr_sugar as {casex = Const (c', _), ...}) => if c' = c then SOME ctr_sugar else NONE
+ | _ => NONE);
+ fun add_ctr_sugar (s, Type (@{type_name fun}, [_, T])) =
+ binder_types T
+ |> map_filter (try (fst o dest_Type))
+ |> distinct (op =)
+ |> map_filter (ctr_sugar_of_case s)
+ | add_ctr_sugar _ = [];
+
+ val ctr_sugars = maps add_ctr_sugar (Term.add_consts t []);
+ in
(maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #disc_exhausts ctr_sugars,
maps #sel_splits ctr_sugars, maps #sel_split_asms ctr_sugars)
end;
@@ -884,10 +896,9 @@
let
val sel_no = find_first (curry (op =) ctr o #ctr) basic_ctr_specs
|> find_index (curry (op =) sel) o #sels o the;
- fun find t = if has_call t then snd (fold_rev_let_if_case ctxt (K cons) [] t []) else [];
in
- find rhs_term
- |> K |> nth_map sel_no |> AList.map_entry (op =) ctr
+ K (if has_call rhs_term then fold_rev_let_if_case ctxt (K cons) [] rhs_term [] else [])
+ |> nth_map sel_no |> AList.map_entry (op =) ctr
end;
fun applied_fun_of fun_name fun_T fun_args =