improved 'datatype_new_compat': generate more fixpoint equations for types like 'datatype_new x = C (x list) (x list)' (here, one equation for each x list instead of a single for both), for higher compatibility + code generation attributes on the recursor
authorblanchet
Fri Feb 14 15:03:23 2014 +0100 (2014-02-14)
changeset 55479ece4910c3ea0
parent 55478 3a6efda01da4
child 55480 59cc4a8bc28a
improved 'datatype_new_compat': generate more fixpoint equations for types like 'datatype_new x = C (x list) (x list)' (here, one equation for each x list instead of a single for both), for higher compatibility + code generation attributes on the recursor
src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_compat.ML
     1.1 --- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 14:54:08 2014 +0100
     1.2 +++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
     1.3 @@ -15,8 +15,6 @@
     1.4      (BNF_FP_Def_Sugar.fp_sugar list
     1.5       * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
     1.6      * local_theory
     1.7 -  val indexify_callsss: BNF_FP_Def_Sugar.fp_sugar -> (term * term list list) list ->
     1.8 -    term list list list
     1.9    val nested_to_mutual_fps: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
    1.10      (term * term list list) list list -> local_theory ->
    1.11      (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
    1.12 @@ -112,8 +110,8 @@
    1.13  
    1.14      val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
    1.15  
    1.16 -    fun incompatible_calls t1 t2 =
    1.17 -      error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
    1.18 +    fun incompatible_calls ts =
    1.19 +      error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ commas (map qsotm ts));
    1.20      fun nested_self_call t =
    1.21        error ("Unsupported nested self-call " ^ qsotm t);
    1.22  
    1.23 @@ -147,13 +145,20 @@
    1.24        ||>> variant_tfrees fp_b_names;
    1.25  
    1.26      fun check_call_dead live_call call =
    1.27 -      if null (get_indices call) then () else incompatible_calls live_call call;
    1.28 +      if null (get_indices call) then () else incompatible_calls [live_call, call];
    1.29  
    1.30 -    fun freeze_fpTs_simple (T as Type (s, Ts)) =
    1.31 +    fun freeze_fpTs_default (T as Type (s, Ts)) =
    1.32          (case find_index (curry (op =) T) fpTs of
    1.33 -          ~1 => Type (s, map freeze_fpTs_simple Ts)
    1.34 +          ~1 => Type (s, map freeze_fpTs_default Ts)
    1.35          | kk => nth Xs kk)
    1.36 -      | freeze_fpTs_simple T = T;
    1.37 +      | freeze_fpTs_default T = T;
    1.38 +
    1.39 +    fun freeze_fpTs_simple calls (T as Type (s, Ts)) =
    1.40 +        (case fold (union (op =)) (map get_indices calls) [] of
    1.41 +          [] => freeze_fpTs_default T
    1.42 +        | [kk] => nth Xs kk
    1.43 +        | _ => incompatible_calls calls)
    1.44 +      | freeze_fpTs_simple _ T = T;
    1.45  
    1.46      fun freeze_fpTs_map (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls))
    1.47          (T as Type (s, Ts)) =
    1.48 @@ -167,7 +172,7 @@
    1.49          (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
    1.50            ([], _) =>
    1.51            (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
    1.52 -            ([], _) => freeze_fpTs_simple T
    1.53 +            ([], _) => freeze_fpTs_simple calls T
    1.54            | callsp => freeze_fpTs_map fpT callsp T)
    1.55          | callsp => freeze_fpTs_map fpT callsp T)
    1.56        | freeze_fpTs _ _ T = T;
    1.57 @@ -251,6 +256,7 @@
    1.58        end)
    1.59    end;
    1.60  
    1.61 +(* TODO: needed? *)
    1.62  fun indexify_callsss fp_sugar callsss =
    1.63    let
    1.64      val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
     2.1 --- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 14:54:08 2014 +0100
     2.2 +++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 15:03:23 2014 +0100
     2.3 @@ -27,6 +27,8 @@
     2.4  
     2.5  val compatN = "compat_";
     2.6  
     2.7 +val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
     2.8 +
     2.9  (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
    2.10  fun datatype_new_compat_cmd raw_fpT_names lthy =
    2.11    let
    2.12 @@ -54,44 +56,63 @@
    2.13      val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
    2.14      val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
    2.15  
    2.16 -    fun add_nested_types_of (T as Type (s, _)) seen =
    2.17 -      if member (op =) seen T then
    2.18 -        seen
    2.19 -      else if s = @{type_name fun} then
    2.20 -        (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
    2.21 -      else
    2.22 -        (case try lfp_sugar_of s of
    2.23 -          SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
    2.24 -          let
    2.25 -            val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
    2.26 -            val substT = Term.typ_subst_TVars rho;
    2.27 -
    2.28 -            val mutual_Ts = map substT mutual_Ts0;
    2.29 +    fun nested_Tindicessss_of parent_Tkks (T as Type (s, _)) kk =
    2.30 +      (case try lfp_sugar_of s of
    2.31 +        SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
    2.32 +        let
    2.33 +          val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
    2.34 +          val substT = Term.typ_subst_TVars rho;
    2.35 +          val mutual_Ts = map substT mutual_Ts0;
    2.36 +          val mutual_nn = length mutual_Ts;
    2.37 +          val mutual_kks = kk upto kk + mutual_nn - 1;
    2.38 +          val mutual_Tkks = mutual_Ts ~~ mutual_kks;
    2.39  
    2.40 -            fun add_interesting_subtypes (U as Type (_, Us)) =
    2.41 -                (case filter (exists_subtype_in mutual_Ts) Us of [] => I
    2.42 -                | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
    2.43 -              | add_interesting_subtypes _ = I;
    2.44 +          fun Tindices_of_ctr_arg (parent_Tkks as (_, parent_kk) :: _) (U as Type (s, _))
    2.45 +                (accum as (Tkssss, kk')) =
    2.46 +              if s = @{type_name fun} then
    2.47 +                if exists_subtype_in mutual_Ts U then
    2.48 +                  (warning "Incomplete support for recursion through functions -- \
    2.49 +                     \'primrec' will fail";
    2.50 +                   Tindices_of_ctr_arg parent_Tkks (range_type U) accum)
    2.51 +                else
    2.52 +                  ([], accum)
    2.53 +              else
    2.54 +                (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of
    2.55 +                  SOME kk => ([kk], accum)
    2.56 +                | NONE =>
    2.57 +                  if exists_subtype_in mutual_Ts U then
    2.58 +                    ([kk'], nested_Tindicessss_of parent_Tkks U kk' |>> append Tkssss)
    2.59 +                  else
    2.60 +                    ([], accum))
    2.61 +            | Tindices_of_ctr_arg _ _ accum = ([], accum);
    2.62  
    2.63 -            val ctrs = maps #ctrs ctr_sugars;
    2.64 -            val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
    2.65 -            val subTs = fold add_interesting_subtypes ctr_Ts [];
    2.66 -          in
    2.67 -            fold add_nested_types_of subTs (seen @ mutual_Ts)
    2.68 -          end
    2.69 -        | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
    2.70 -            " not corresponding to new-style datatype (cf. \"datatype_new\")"));
    2.71 +          fun Tindicesss_of_mutual_type T kk ctr_Tss =
    2.72 +            fold_map (fold_map (Tindices_of_ctr_arg ((T, kk) :: parent_Tkks))) ctr_Tss
    2.73 +            #>> pair T;
    2.74  
    2.75 -    val Ts = add_nested_types_of fpT1 [];
    2.76 +          val ctrss = map #ctrs ctr_sugars;
    2.77 +          val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss;
    2.78 +        in
    2.79 +          ([], kk + mutual_nn)
    2.80 +          |> fold_map3 Tindicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss
    2.81 +          |> (fn (Tkkssss, (Tkkssss', kk)) => (Tkkssss @ Tkkssss', kk))
    2.82 +        end
    2.83 +      | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
    2.84 +          " not corresponding to new-style datatype (cf. \"datatype_new\")"));
    2.85 +
    2.86 +    fun get_indices (Bound kk) = [kk];
    2.87 +
    2.88 +    val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0;
    2.89 +    val Ts = map fst Tkkssss;
    2.90 +    val callssss = map (map (map (map Bound)) o snd) Tkkssss;
    2.91 +
    2.92      val b_names = map base_name_of_typ Ts;
    2.93      val compat_b_names = map (prefix compatN) b_names;
    2.94      val compat_bs = map Binding.name compat_b_names;
    2.95      val common_name = compatN ^ mk_common_name b_names;
    2.96      val nn_fp = length fpTs;
    2.97      val nn = length Ts;
    2.98 -    val get_indices = K [];
    2.99      val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
   2.100 -    val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;
   2.101  
   2.102      val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
   2.103        if nn > nn_fp then
   2.104 @@ -142,7 +163,7 @@
   2.105            val notes =
   2.106              [(foldN, fold_thmss, []),
   2.107               (inductN, map single induct_thms, induct_attrs),
   2.108 -             (recN, rec_thmss, [])]
   2.109 +             (recN, rec_thmss, code_nitpicksimp_simp_attrs)]
   2.110              |> filter_out (null o #2)
   2.111              |> maps (fn (thmN, thmss, attrs) =>
   2.112                if forall null thmss then