--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 14:54:08 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 15:03:23 2014 +0100
@@ -27,6 +27,8 @@
val compatN = "compat_";
+val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
+
(* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
fun datatype_new_compat_cmd raw_fpT_names lthy =
let
@@ -54,44 +56,63 @@
val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
- fun add_nested_types_of (T as Type (s, _)) seen =
- if member (op =) seen T then
- seen
- else if s = @{type_name fun} then
- (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
- else
- (case try lfp_sugar_of s of
- SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
- let
- val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
- val substT = Term.typ_subst_TVars rho;
-
- val mutual_Ts = map substT mutual_Ts0;
+ fun nested_Tindicessss_of parent_Tkks (T as Type (s, _)) kk =
+ (case try lfp_sugar_of s of
+ SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
+ let
+ val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
+ val substT = Term.typ_subst_TVars rho;
+ val mutual_Ts = map substT mutual_Ts0;
+ val mutual_nn = length mutual_Ts;
+ val mutual_kks = kk upto kk + mutual_nn - 1;
+ val mutual_Tkks = mutual_Ts ~~ mutual_kks;
- fun add_interesting_subtypes (U as Type (_, Us)) =
- (case filter (exists_subtype_in mutual_Ts) Us of [] => I
- | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
- | add_interesting_subtypes _ = I;
+ fun Tindices_of_ctr_arg (parent_Tkks as (_, parent_kk) :: _) (U as Type (s, _))
+ (accum as (Tkssss, kk')) =
+ if s = @{type_name fun} then
+ if exists_subtype_in mutual_Ts U then
+ (warning "Incomplete support for recursion through functions -- \
+ \'primrec' will fail";
+ Tindices_of_ctr_arg parent_Tkks (range_type U) accum)
+ else
+ ([], accum)
+ else
+ (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of
+ SOME kk => ([kk], accum)
+ | NONE =>
+ if exists_subtype_in mutual_Ts U then
+ ([kk'], nested_Tindicessss_of parent_Tkks U kk' |>> append Tkssss)
+ else
+ ([], accum))
+ | Tindices_of_ctr_arg _ _ accum = ([], accum);
- val ctrs = maps #ctrs ctr_sugars;
- val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
- val subTs = fold add_interesting_subtypes ctr_Ts [];
- in
- fold add_nested_types_of subTs (seen @ mutual_Ts)
- end
- | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
- " not corresponding to new-style datatype (cf. \"datatype_new\")"));
+ fun Tindicesss_of_mutual_type T kk ctr_Tss =
+ fold_map (fold_map (Tindices_of_ctr_arg ((T, kk) :: parent_Tkks))) ctr_Tss
+ #>> pair T;
- val Ts = add_nested_types_of fpT1 [];
+ val ctrss = map #ctrs ctr_sugars;
+ val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss;
+ in
+ ([], kk + mutual_nn)
+ |> fold_map3 Tindicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss
+ |> (fn (Tkkssss, (Tkkssss', kk)) => (Tkkssss @ Tkkssss', kk))
+ end
+ | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
+ " not corresponding to new-style datatype (cf. \"datatype_new\")"));
+
+ fun get_indices (Bound kk) = [kk];
+
+ val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0;
+ val Ts = map fst Tkkssss;
+ val callssss = map (map (map (map Bound)) o snd) Tkkssss;
+
val b_names = map base_name_of_typ Ts;
val compat_b_names = map (prefix compatN) b_names;
val compat_bs = map Binding.name compat_b_names;
val common_name = compatN ^ mk_common_name b_names;
val nn_fp = length fpTs;
val nn = length Ts;
- val get_indices = K [];
val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
- val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;
val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
if nn > nn_fp then
@@ -142,7 +163,7 @@
val notes =
[(foldN, fold_thmss, []),
(inductN, map single induct_thms, induct_attrs),
- (recN, rec_thmss, [])]
+ (recN, rec_thmss, code_nitpicksimp_simp_attrs)]
|> filter_out (null o #2)
|> maps (fn (thmN, thmss, attrs) =>
if forall null thmss then