improved 'datatype_new_compat': generate more fixpoint equations for types like 'datatype_new x = C (x list) (x list)' (here, one equation for each x list instead of a single for both), for higher compatibility + code generation attributes on the recursor
--- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML Fri Feb 14 14:54:08 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML Fri Feb 14 15:03:23 2014 +0100
@@ -15,8 +15,6 @@
(BNF_FP_Def_Sugar.fp_sugar list
* (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
* local_theory
- val indexify_callsss: BNF_FP_Def_Sugar.fp_sugar -> (term * term list list) list ->
- term list list list
val nested_to_mutual_fps: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
(term * term list list) list list -> local_theory ->
(typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
@@ -112,8 +110,8 @@
val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
- fun incompatible_calls t1 t2 =
- error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
+ fun incompatible_calls ts =
+ error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ commas (map qsotm ts));
fun nested_self_call t =
error ("Unsupported nested self-call " ^ qsotm t);
@@ -147,13 +145,20 @@
||>> variant_tfrees fp_b_names;
fun check_call_dead live_call call =
- if null (get_indices call) then () else incompatible_calls live_call call;
+ if null (get_indices call) then () else incompatible_calls [live_call, call];
- fun freeze_fpTs_simple (T as Type (s, Ts)) =
+ fun freeze_fpTs_default (T as Type (s, Ts)) =
(case find_index (curry (op =) T) fpTs of
- ~1 => Type (s, map freeze_fpTs_simple Ts)
+ ~1 => Type (s, map freeze_fpTs_default Ts)
| kk => nth Xs kk)
- | freeze_fpTs_simple T = T;
+ | freeze_fpTs_default T = T;
+
+ fun freeze_fpTs_simple calls (T as Type (s, Ts)) =
+ (case fold (union (op =)) (map get_indices calls) [] of
+ [] => freeze_fpTs_default T
+ | [kk] => nth Xs kk
+ | _ => incompatible_calls calls)
+ | freeze_fpTs_simple _ T = T;
fun freeze_fpTs_map (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls))
(T as Type (s, Ts)) =
@@ -167,7 +172,7 @@
(case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
([], _) =>
(case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
- ([], _) => freeze_fpTs_simple T
+ ([], _) => freeze_fpTs_simple calls T
| callsp => freeze_fpTs_map fpT callsp T)
| callsp => freeze_fpTs_map fpT callsp T)
| freeze_fpTs _ _ T = T;
@@ -251,6 +256,7 @@
end)
end;
+(* TODO: needed? *)
fun indexify_callsss fp_sugar callsss =
let
val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 14:54:08 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 15:03:23 2014 +0100
@@ -27,6 +27,8 @@
val compatN = "compat_";
+val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
+
(* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
fun datatype_new_compat_cmd raw_fpT_names lthy =
let
@@ -54,44 +56,63 @@
val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
- fun add_nested_types_of (T as Type (s, _)) seen =
- if member (op =) seen T then
- seen
- else if s = @{type_name fun} then
- (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
- else
- (case try lfp_sugar_of s of
- SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
- let
- val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
- val substT = Term.typ_subst_TVars rho;
-
- val mutual_Ts = map substT mutual_Ts0;
+ fun nested_Tindicessss_of parent_Tkks (T as Type (s, _)) kk =
+ (case try lfp_sugar_of s of
+ SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
+ let
+ val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
+ val substT = Term.typ_subst_TVars rho;
+ val mutual_Ts = map substT mutual_Ts0;
+ val mutual_nn = length mutual_Ts;
+ val mutual_kks = kk upto kk + mutual_nn - 1;
+ val mutual_Tkks = mutual_Ts ~~ mutual_kks;
- fun add_interesting_subtypes (U as Type (_, Us)) =
- (case filter (exists_subtype_in mutual_Ts) Us of [] => I
- | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
- | add_interesting_subtypes _ = I;
+ fun Tindices_of_ctr_arg (parent_Tkks as (_, parent_kk) :: _) (U as Type (s, _))
+ (accum as (Tkssss, kk')) =
+ if s = @{type_name fun} then
+ if exists_subtype_in mutual_Ts U then
+ (warning "Incomplete support for recursion through functions -- \
+ \'primrec' will fail";
+ Tindices_of_ctr_arg parent_Tkks (range_type U) accum)
+ else
+ ([], accum)
+ else
+ (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of
+ SOME kk => ([kk], accum)
+ | NONE =>
+ if exists_subtype_in mutual_Ts U then
+ ([kk'], nested_Tindicessss_of parent_Tkks U kk' |>> append Tkssss)
+ else
+ ([], accum))
+ | Tindices_of_ctr_arg _ _ accum = ([], accum);
- val ctrs = maps #ctrs ctr_sugars;
- val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
- val subTs = fold add_interesting_subtypes ctr_Ts [];
- in
- fold add_nested_types_of subTs (seen @ mutual_Ts)
- end
- | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
- " not corresponding to new-style datatype (cf. \"datatype_new\")"));
+ fun Tindicesss_of_mutual_type T kk ctr_Tss =
+ fold_map (fold_map (Tindices_of_ctr_arg ((T, kk) :: parent_Tkks))) ctr_Tss
+ #>> pair T;
- val Ts = add_nested_types_of fpT1 [];
+ val ctrss = map #ctrs ctr_sugars;
+ val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss;
+ in
+ ([], kk + mutual_nn)
+ |> fold_map3 Tindicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss
+ |> (fn (Tkkssss, (Tkkssss', kk)) => (Tkkssss @ Tkkssss', kk))
+ end
+ | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
+ " not corresponding to new-style datatype (cf. \"datatype_new\")"));
+
+ fun get_indices (Bound kk) = [kk];
+
+ val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0;
+ val Ts = map fst Tkkssss;
+ val callssss = map (map (map (map Bound)) o snd) Tkkssss;
+
val b_names = map base_name_of_typ Ts;
val compat_b_names = map (prefix compatN) b_names;
val compat_bs = map Binding.name compat_b_names;
val common_name = compatN ^ mk_common_name b_names;
val nn_fp = length fpTs;
val nn = length Ts;
- val get_indices = K [];
val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
- val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;
val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
if nn > nn_fp then
@@ -142,7 +163,7 @@
val notes =
[(foldN, fold_thmss, []),
(inductN, map single induct_thms, induct_attrs),
- (recN, rec_thmss, [])]
+ (recN, rec_thmss, code_nitpicksimp_simp_attrs)]
|> filter_out (null o #2)
|> maps (fn (thmN, thmss, attrs) =>
if forall null thmss then