improved 'datatype_new_compat': generate more fixpoint equations for types like 'datatype_new x = C (x list) (x list)' (here, one equation for each x list instead of a single for both), for higher compatibility + code generation attributes on the recursor
authorblanchet
Fri, 14 Feb 2014 15:03:23 +0100
changeset 55479 ece4910c3ea0
parent 55478 3a6efda01da4
child 55480 59cc4a8bc28a
improved 'datatype_new_compat': generate more fixpoint equations for types like 'datatype_new x = C (x list) (x list)' (here, one equation for each x list instead of a single for both), for higher compatibility + code generation attributes on the recursor
src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_compat.ML
--- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 14:54:08 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
@@ -15,8 +15,6 @@
     (BNF_FP_Def_Sugar.fp_sugar list
      * (BNF_FP_Def_Sugar.lfp_sugar_thms option * BNF_FP_Def_Sugar.gfp_sugar_thms option))
     * local_theory
-  val indexify_callsss: BNF_FP_Def_Sugar.fp_sugar -> (term * term list list) list ->
-    term list list list
   val nested_to_mutual_fps: BNF_FP_Util.fp_kind -> binding list -> typ list -> (term -> int list) ->
     (term * term list list) list list -> local_theory ->
     (typ list * int list * BNF_FP_Def_Sugar.fp_sugar list
@@ -112,8 +110,8 @@
 
     val qsotm = quote o Syntax.string_of_term no_defs_lthy0;
 
-    fun incompatible_calls t1 t2 =
-      error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ qsotm t1 ^ " vs. " ^ qsotm t2);
+    fun incompatible_calls ts =
+      error ("Incompatible " ^ co_prefix fp ^ "recursive calls: " ^ commas (map qsotm ts));
     fun nested_self_call t =
       error ("Unsupported nested self-call " ^ qsotm t);
 
@@ -147,13 +145,20 @@
       ||>> variant_tfrees fp_b_names;
 
     fun check_call_dead live_call call =
-      if null (get_indices call) then () else incompatible_calls live_call call;
+      if null (get_indices call) then () else incompatible_calls [live_call, call];
 
-    fun freeze_fpTs_simple (T as Type (s, Ts)) =
+    fun freeze_fpTs_default (T as Type (s, Ts)) =
         (case find_index (curry (op =) T) fpTs of
-          ~1 => Type (s, map freeze_fpTs_simple Ts)
+          ~1 => Type (s, map freeze_fpTs_default Ts)
         | kk => nth Xs kk)
-      | freeze_fpTs_simple T = T;
+      | freeze_fpTs_default T = T;
+
+    fun freeze_fpTs_simple calls (T as Type (s, Ts)) =
+        (case fold (union (op =)) (map get_indices calls) [] of
+          [] => freeze_fpTs_default T
+        | [kk] => nth Xs kk
+        | _ => incompatible_calls calls)
+      | freeze_fpTs_simple _ T = T;
 
     fun freeze_fpTs_map (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls))
         (T as Type (s, Ts)) =
@@ -167,7 +172,7 @@
         (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
           ([], _) =>
           (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
-            ([], _) => freeze_fpTs_simple T
+            ([], _) => freeze_fpTs_simple calls T
           | callsp => freeze_fpTs_map fpT callsp T)
         | callsp => freeze_fpTs_map fpT callsp T)
       | freeze_fpTs _ _ T = T;
@@ -251,6 +256,7 @@
       end)
   end;
 
+(* TODO: needed? *)
 fun indexify_callsss fp_sugar callsss =
   let
     val {ctrs, ...} = of_fp_sugar #ctr_sugars fp_sugar;
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 14:54:08 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 15:03:23 2014 +0100
@@ -27,6 +27,8 @@
 
 val compatN = "compat_";
 
+val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
+
 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
 fun datatype_new_compat_cmd raw_fpT_names lthy =
   let
@@ -54,44 +56,63 @@
     val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
     val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
 
-    fun add_nested_types_of (T as Type (s, _)) seen =
-      if member (op =) seen T then
-        seen
-      else if s = @{type_name fun} then
-        (warning "Partial support for recursion through functions -- 'primrec' will fail"; seen)
-      else
-        (case try lfp_sugar_of s of
-          SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
-          let
-            val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
-            val substT = Term.typ_subst_TVars rho;
-
-            val mutual_Ts = map substT mutual_Ts0;
+    fun nested_Tindicessss_of parent_Tkks (T as Type (s, _)) kk =
+      (case try lfp_sugar_of s of
+        SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
+        let
+          val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
+          val substT = Term.typ_subst_TVars rho;
+          val mutual_Ts = map substT mutual_Ts0;
+          val mutual_nn = length mutual_Ts;
+          val mutual_kks = kk upto kk + mutual_nn - 1;
+          val mutual_Tkks = mutual_Ts ~~ mutual_kks;
 
-            fun add_interesting_subtypes (U as Type (_, Us)) =
-                (case filter (exists_subtype_in mutual_Ts) Us of [] => I
-                | Us' => insert (op =) U #> fold add_interesting_subtypes Us')
-              | add_interesting_subtypes _ = I;
+          fun Tindices_of_ctr_arg (parent_Tkks as (_, parent_kk) :: _) (U as Type (s, _))
+                (accum as (Tkssss, kk')) =
+              if s = @{type_name fun} then
+                if exists_subtype_in mutual_Ts U then
+                  (warning "Incomplete support for recursion through functions -- \
+                     \'primrec' will fail";
+                   Tindices_of_ctr_arg parent_Tkks (range_type U) accum)
+                else
+                  ([], accum)
+              else
+                (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of
+                  SOME kk => ([kk], accum)
+                | NONE =>
+                  if exists_subtype_in mutual_Ts U then
+                    ([kk'], nested_Tindicessss_of parent_Tkks U kk' |>> append Tkssss)
+                  else
+                    ([], accum))
+            | Tindices_of_ctr_arg _ _ accum = ([], accum);
 
-            val ctrs = maps #ctrs ctr_sugars;
-            val ctr_Ts = maps (binder_types o substT o fastype_of) ctrs |> distinct (op =);
-            val subTs = fold add_interesting_subtypes ctr_Ts [];
-          in
-            fold add_nested_types_of subTs (seen @ mutual_Ts)
-          end
-        | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
-            " not corresponding to new-style datatype (cf. \"datatype_new\")"));
+          fun Tindicesss_of_mutual_type T kk ctr_Tss =
+            fold_map (fold_map (Tindices_of_ctr_arg ((T, kk) :: parent_Tkks))) ctr_Tss
+            #>> pair T;
 
-    val Ts = add_nested_types_of fpT1 [];
+          val ctrss = map #ctrs ctr_sugars;
+          val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss;
+        in
+          ([], kk + mutual_nn)
+          |> fold_map3 Tindicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss
+          |> (fn (Tkkssss, (Tkkssss', kk)) => (Tkkssss @ Tkkssss', kk))
+        end
+      | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
+          " not corresponding to new-style datatype (cf. \"datatype_new\")"));
+
+    fun get_indices (Bound kk) = [kk];
+
+    val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0;
+    val Ts = map fst Tkkssss;
+    val callssss = map (map (map (map Bound)) o snd) Tkkssss;
+
     val b_names = map base_name_of_typ Ts;
     val compat_b_names = map (prefix compatN) b_names;
     val compat_bs = map Binding.name compat_b_names;
     val common_name = compatN ^ mk_common_name b_names;
     val nn_fp = length fpTs;
     val nn = length Ts;
-    val get_indices = K [];
     val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
-    val callssss = map (fn fp_sugar0 => indexify_callsss fp_sugar0 []) fp_sugars0;
 
     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
       if nn > nn_fp then
@@ -142,7 +163,7 @@
           val notes =
             [(foldN, fold_thmss, []),
              (inductN, map single induct_thms, induct_attrs),
-             (recN, rec_thmss, [])]
+             (recN, rec_thmss, code_nitpicksimp_simp_attrs)]
             |> filter_out (null o #2)
             |> maps (fn (thmN, thmss, attrs) =>
               if forall null thmss then