# HG changeset patch # User blanchet # Date 1392386603 -3600 # Node ID ece4910c3ea01b02697ee0329dc11452be61f4be # Parent 3a6efda01da476f0cde122a2deab2e54defedaf7 improved 'datatype_new_compat': generate more fixpoint equations for types like 'datatype_new x = C (x list) (x list)' (here, one equation for each x list instead of a single for both), for higher compatibility + code generation attributes on the recursor diff -r 3a6efda01da4 -r ece4910c3ea0 src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML --- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML Fri Feb 14 14:54:08 2014 +0100 +++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML Fri Feb 14 15:03:23 2014 +0100 @@ -15,8 +15,6 @@ (BNF_FP_Def_Sugar.fp_sugar list * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option)) * local_theory - val indexify_callsss: BNF_FP_Def_Sugar.fp_sugar -> (term * term list list) list -> - term list list list val nested_to_mutual_fps: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) -> (term * term list list) list list -> local_theory -> (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list @@ -112,8 +110,8 @@ val qsotm = quote o Syntax.string_of_term no_defs_lthy0; - fun incompatible_calls t1 t2 = - error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2); + fun incompatible_calls ts = + error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ commas (map qsotm ts)); fun nested_self_call t = error ("Unsupported nested self-call " ^ qsotm t); @@ -147,13 +145,20 @@ ||>> variant_tfrees fp_b_names; fun check_call_dead live_call call = - if null (get_indices call) then () else incompatible_calls live_call call; + if null (get_indices call) then () else incompatible_calls [live_call, call]; - fun freeze_fpTs_simple (T as Type (s, Ts)) = + fun freeze_fpTs_default (T as Type (s, Ts)) = (case find_index (curry (op =) T) fpTs of - ~1 => Type (s, map freeze_fpTs_simple Ts) + ~1 => Type (s, map freeze_fpTs_default Ts) | kk => nth Xs kk) - | freeze_fpTs_simple T = T; + | freeze_fpTs_default T = T; + + fun freeze_fpTs_simple calls (T as Type (s, Ts)) = + (case fold (union (op =)) (map get_indices calls) [] of + [] => freeze_fpTs_default T + | [kk] => nth Xs kk + | _ => incompatible_calls calls) + | freeze_fpTs_simple _ T = T; fun freeze_fpTs_map (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls)) (T as Type (s, Ts)) = @@ -167,7 +172,7 @@ (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of ([], _) => (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of - ([], _) => freeze_fpTs_simple T + ([], _) => freeze_fpTs_simple calls T | callsp => freeze_fpTs_map fpT callsp T) | callsp => freeze_fpTs_map fpT callsp T) | freeze_fpTs _ _ T = T; @@ -251,6 +256,7 @@ end) end; +(* TODO: needed? *) fun indexify_callsss fp_sugar callsss = let val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar; diff -r 3a6efda01da4 -r ece4910c3ea0 src/HOL/Tools/BNF/bnf_lfp_compat.ML --- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 14:54:08 2014 +0100 +++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 15:03:23 2014 +0100 @@ -27,6 +27,8 @@ val compatN = "compat_"; +val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]}; + (* TODO: graceful failure for local datatypes -- perhaps by making the command global *) fun datatype_new_compat_cmd raw_fpT_names lthy = let @@ -54,44 +56,63 @@ val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As; val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names'; - fun add_nested_types_of (T as Type (s, _)) seen = - if member (op =) seen T then - seen - else if s = @{type_name fun} then - (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen) - else - (case try lfp_sugar_of s of - SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) => - let - val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) []; - val substT = Term.typ_subst_TVars rho; - - val mutual_Ts = map substT mutual_Ts0; + fun nested_Tindicessss_of parent_Tkks (T as Type (s, _)) kk = + (case try lfp_sugar_of s of + SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) => + let + val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) []; + val substT = Term.typ_subst_TVars rho; + val mutual_Ts = map substT mutual_Ts0; + val mutual_nn = length mutual_Ts; + val mutual_kks = kk upto kk + mutual_nn - 1; + val mutual_Tkks = mutual_Ts ~~ mutual_kks; - fun add_interesting_subtypes (U as Type (_, Us)) = - (case filter (exists_subtype_in mutual_Ts) Us of [] => I - | Us' => insert (op =) U #> fold add_interesting_subtypes Us') - | add_interesting_subtypes _ = I; + fun Tindices_of_ctr_arg (parent_Tkks as (_, parent_kk) :: _) (U as Type (s, _)) + (accum as (Tkssss, kk')) = + if s = @{type_name fun} then + if exists_subtype_in mutual_Ts U then + (warning "Incomplete support for recursion through functions -- \ + \'primrec' will fail"; + Tindices_of_ctr_arg parent_Tkks (range_type U) accum) + else + ([], accum) + else + (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of + SOME kk => ([kk], accum) + | NONE => + if exists_subtype_in mutual_Ts U then + ([kk'], nested_Tindicessss_of parent_Tkks U kk' |>> append Tkssss) + else + ([], accum)) + | Tindices_of_ctr_arg _ _ accum = ([], accum); - val ctrs = maps #ctrs ctr_sugars; - val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =); - val subTs = fold add_interesting_subtypes ctr_Ts []; - in - fold add_nested_types_of subTs (seen @ mutual_Ts) - end - | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^ - " not corresponding to new-style datatype (cf. \"datatype_new\")")); + fun Tindicesss_of_mutual_type T kk ctr_Tss = + fold_map (fold_map (Tindices_of_ctr_arg ((T, kk) :: parent_Tkks))) ctr_Tss + #>> pair T; - val Ts = add_nested_types_of fpT1 []; + val ctrss = map #ctrs ctr_sugars; + val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss; + in + ([], kk + mutual_nn) + |> fold_map3 Tindicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss + |> (fn (Tkkssss, (Tkkssss', kk)) => (Tkkssss @ Tkkssss', kk)) + end + | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^ + " not corresponding to new-style datatype (cf. \"datatype_new\")")); + + fun get_indices (Bound kk) = [kk]; + + val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0; + val Ts = map fst Tkkssss; + val callssss = map (map (map (map Bound)) o snd) Tkkssss; + val b_names = map base_name_of_typ Ts; val compat_b_names = map (prefix compatN) b_names; val compat_bs = map Binding.name compat_b_names; val common_name = compatN ^ mk_common_name b_names; val nn_fp = length fpTs; val nn = length Ts; - val get_indices = K []; val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts; - val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0; val ((fp_sugars, (lfp_sugar_thms, _)), lthy) = if nn > nn_fp then @@ -142,7 +163,7 @@ val notes = [(foldN, fold_thmss, []), (inductN, map single induct_thms, induct_attrs), - (recN, rec_thmss, [])] + (recN, rec_thmss, code_nitpicksimp_simp_attrs)] |> filter_out (null o #2) |> maps (fn (thmN, thmss, attrs) => if forall null thmss then