support deeply nested datatypes in 'datatype_compat'
authorblanchet
Tue, 08 Apr 2014 17:49:03 +0200
changeset 56453 00548d372f02
parent 56452 0c98c9118407
child 56454 e9e82384e5a1
support deeply nested datatypes in 'datatype_compat'
src/HOL/Tools/BNF/bnf_lfp_compat.ML
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Tue Apr 08 13:47:27 2014 +0200
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Tue Apr 08 17:49:03 2014 +0200
@@ -23,6 +23,22 @@
 
 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
 
+fun reindex_desc desc =
+  let
+    val kks = map fst desc;
+    val perm_kks = sort int_ord kks;
+
+    fun perm_dtyp (Datatype_Aux.DtType (s, Ds)) = Datatype_Aux.DtType (s, map perm_dtyp Ds)
+      | perm_dtyp (Datatype_Aux.DtRec kk) = Datatype_Aux.DtRec (find_index (curry (op =) kk) kks)
+      | perm_dtyp D = D
+  in
+    if perm_kks = kks then
+      desc
+    else
+      perm_kks ~~
+      map (fn (_, (s, Ds, sDss)) => (s, map perm_dtyp Ds, map (apsnd (map perm_dtyp)) sDss)) desc
+  end
+
 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
 fun datatype_compat_cmd raw_fpT_names lthy =
   let
@@ -49,73 +65,33 @@
 
     val (As_names, _) = lthy |> Variable.variant_fixes (map (fn TVar ((s, _), _) => s) var_As);
     val As = map2 (fn s => fn TVar (_, S) => TFree (s, S)) As_names var_As;
-    val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
+    val fpTs = map (fn s => Type (s, As)) fpT_names';
+
+    val nn_fp = length fpTs;
 
-    fun nested_Tparentss_indicessss_of parent_Tkks (T as Type (s, _)) kk =
-      (case try lfp_sugar_of s of
-        SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ...}) =>
-        let
-          val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
-          val substT = Term.typ_subst_TVars rho;
-          val mutual_Ts = map substT mutual_Ts0;
-          val mutual_nn = length mutual_Ts;
-          val mutual_kks = kk upto kk + mutual_nn - 1;
-          val mutual_Tkks = mutual_Ts ~~ mutual_kks;
+    val mk_dtyp = Datatype_Aux.dtyp_of_typ (map (apsnd (map Term.dest_TFree) o dest_Type) fpTs);
+
+    fun mk_ctr_descr Ts = mk_ctr Ts #> dest_Const ##> (binder_types #> map mk_dtyp);
+    fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
+      (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs));
 
-          fun indices_of_ctr_arg parent_Tkks (U as Type (s, Us)) (accum as (Tparents_ksss, kk')) =
-              if s = @{type_name fun} then
-                if exists_subtype_in mutual_Ts U then
-                  (warning "Incomplete support for recursion through functions -- \
-                     \the old 'primrec' will fail";
-                   indices_of_ctr_arg parent_Tkks (range_type U) accum)
-                else
-                  ([], accum)
-              else
-                (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of
-                  SOME kk => ([kk], accum)
-                | NONE =>
-                  if exists (exists_strict_subtype_in mutual_Ts) Us then
-                    error "Deeply nested recursion not supported"
-                  else if exists (member (op =) mutual_Ts) Us then
-                    ([kk'],
-                     nested_Tparentss_indicessss_of parent_Tkks U kk' |>> append Tparents_ksss)
-                  else
-                    ([], accum))
-            | indices_of_ctr_arg _ _ accum = ([], accum);
+    val orig_descr = map3 mk_typ_descr (0 upto nn_fp - 1) fpTs fp_ctr_sugars;
+    val all_infos = Datatype_Data.get_all thy;
+    val (orig_descr' :: nested_descrs, _) =
+      Datatype_Aux.unfold_datatypes lthy orig_descr all_infos orig_descr nn_fp;
 
-          fun Tparents_indicesss_of_mutual_type T kk ctr_Tss =
-            let val parent_Tkks' = (T, kk) :: parent_Tkks in
-              fold_map (fold_map (indices_of_ctr_arg parent_Tkks')) ctr_Tss
-              #>> pair parent_Tkks'
-            end;
+    (* put nested types before the types that nest them, as needed for N2M *)
+    val descr = reindex_desc (orig_descr' @ flat (rev nested_descrs));
 
-          val ctrss = map (#ctrs o #ctr_sugar o lfp_sugar_of o fst o dest_Type) mutual_Ts;
-          val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss;
-        in
-          ([], kk + mutual_nn)
-          |> fold_map3 Tparents_indicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss
-          |> (fn (Tparentss_kkssss, (Tparentss_kkssss', kk)) =>
-            (Tparentss_kkssss @ Tparentss_kkssss', kk))
-        end
-      | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
-          " not corresponding to new-style datatype (cf. \"datatype_new\")"));
+    val dest_dtyp = Datatype_Aux.typ_of_dtyp descr;
 
-    val (Tparentss_kkssss, _) = nested_Tparentss_indicessss_of [] fpT1 0;
-    val Tparentss = map fst Tparentss_kkssss;
-    val Ts = map (fst o hd) Tparentss;
-    val kkssss = map snd Tparentss_kkssss;
+    val Ts = Datatype_Aux.get_rec_types descr;
+    val nn = length Ts;
 
     val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
-    val ctrss0 = map (#ctrs o #ctr_sugar) fp_sugars0;
-    val ctr_Tsss0 = map (map (binder_types o fastype_of)) ctrss0;
-
-    val b_names = Name.variant_list [] (map base_name_of_typ Ts);
-    val compat_b_names = map (prefix compatN) b_names;
-    val compat_bs = map Binding.name compat_b_names;
-    val common_name = compatN ^ mk_common_name b_names;
-
-    val nn_fp = length fpTs;
-    val nn = length Ts;
+    val ctr_Tsss = map (map (map dest_dtyp o snd) o #3 o snd) descr;
+    val kkssss =
+      map (map (map (fn Datatype_Aux.DtRec kk => [kk] | _ => []) o snd) o #3 o snd) descr;
 
     val callers = map (fn kk => Var ((Name.uu, kk), @{typ "unit => unit"})) (0 upto nn - 1);
 
@@ -123,8 +99,12 @@
       mk_partial_compN n (replicate n HOLogic.unitT ---> HOLogic.unitT) (nth callers kk);
 
     val callssss =
-      map2 (map2 (map2 (fn kks => fn ctr_T => map (apply_comps (num_binder_types ctr_T)) kks)))
-        kkssss ctr_Tsss0;
+      map2 (map2 (map2 (fn ctr_T => map (apply_comps (num_binder_types ctr_T))))) ctr_Tsss kkssss;
+
+    val b_names = Name.variant_list [] (map base_name_of_typ Ts);
+    val compat_b_names = map (prefix compatN) b_names;
+    val compat_bs = map Binding.name compat_b_names;
+    val common_name = compatN ^ mk_common_name b_names;
 
     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
       if nn > nn_fp then
@@ -132,24 +112,12 @@
       else
         ((fp_sugars0, (NONE, NONE)), lthy);
 
+    val recs = map (fst o dest_Const o #co_rec) fp_sugars;
+    val rec_thms = maps #co_rec_thms fp_sugars;
+
     val {common_co_inducts = [induct], ...} :: _ = fp_sugars;
     val inducts = map (the_single o #co_inducts) fp_sugars;
 
-    fun mk_dtyp [] (TFree a) = Datatype_Aux.DtTFree a
-      | mk_dtyp [] (Type (s, Ts)) = Datatype_Aux.DtType (s, map (mk_dtyp []) Ts)
-      | mk_dtyp [kk] (Type (@{type_name fun}, [T, T'])) =
-        Datatype_Aux.DtType (@{type_name fun}, [mk_dtyp [] T, mk_dtyp [kk] T'])
-      | mk_dtyp [kk] T = if nth Ts kk = T then Datatype_Aux.DtRec kk else mk_dtyp [] T;
-
-    fun mk_ctr_descr Ts kkss ctr0 =
-      mk_ctr Ts ctr0 |> (fn Const (s, T) => (s, map2 mk_dtyp kkss (binder_types T)));
-    fun mk_typ_descr kksss ((Type (T_name, Ts), kk) :: parents) ctrs0 =
-      (kk, (T_name, map (mk_dtyp (map snd (take 1 parents))) Ts, map2 (mk_ctr_descr Ts) kksss ctrs0));
-
-    val descr = map3 mk_typ_descr kkssss Tparentss ctrss0;
-    val recs = map (fst o dest_Const o #co_rec) fp_sugars;
-    val rec_thms = maps #co_rec_thms fp_sugars;
-
     fun mk_info (kk, {T = Type (T_name0, _), ctr_sugar = {casex, exhaust, nchotomy, injects,
         distincts, case_thms, case_cong, weak_case_cong, split, split_asm, ...}, ...} : fp_sugar) =
       (T_name0,