src/HOL/Tools/BNF/bnf_lfp_compat.ML
changeset 55486 8609527278f2
parent 55485 bdfb607543f4
child 55531 601ca8efa000
equal deleted inserted replaced
55485:bdfb607543f4 55486:8609527278f2
    17 open BNF_Util
    17 open BNF_Util
    18 open BNF_FP_Util
    18 open BNF_FP_Util
    19 open BNF_FP_Def_Sugar
    19 open BNF_FP_Def_Sugar
    20 open BNF_FP_N2M_Sugar
    20 open BNF_FP_N2M_Sugar
    21 
    21 
    22 fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a
       
    23   | dtyp_of_typ recTs (T as Type (s, Ts)) =
       
    24     (case find_index (curry (op =) T) recTs of
       
    25       ~1 => Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Ts)
       
    26     | kk => Datatype_Aux.DtRec kk);
       
    27 
       
    28 val compatN = "compat_";
    22 val compatN = "compat_";
    29 
    23 
    30 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
    24 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]};
    31 
    25 
    32 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
    26 (* TODO: graceful failure for local datatypes -- perhaps by making the command global *)
    54 
    48 
    55     val (unsorted_As, _) = lthy |> mk_TFrees (length var_As);
    49     val (unsorted_As, _) = lthy |> mk_TFrees (length var_As);
    56     val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
    50     val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As;
    57     val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
    51     val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names';
    58 
    52 
    59     fun nested_Tindicessss_of parent_Tkks (T as Type (s, _)) kk =
    53     fun nested_Tparentss_indicessss_of parent_Tkks (T as Type (s, _)) kk =
    60       (case try lfp_sugar_of s of
    54       (case try lfp_sugar_of s of
    61         SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
    55         SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) =>
    62         let
    56         let
    63           val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
    57           val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (T0, T) Vartab.empty) [];
    64           val substT = Term.typ_subst_TVars rho;
    58           val substT = Term.typ_subst_TVars rho;
    65           val mutual_Ts = map substT mutual_Ts0;
    59           val mutual_Ts = map substT mutual_Ts0;
    66           val mutual_nn = length mutual_Ts;
    60           val mutual_nn = length mutual_Ts;
    67           val mutual_kks = kk upto kk + mutual_nn - 1;
    61           val mutual_kks = kk upto kk + mutual_nn - 1;
    68           val mutual_Tkks = mutual_Ts ~~ mutual_kks;
    62           val mutual_Tkks = mutual_Ts ~~ mutual_kks;
    69 
    63 
    70           fun Tindices_of_ctr_arg parent_Tkks (U as Type (s, _)) (accum as (Tkssss, kk')) =
    64           fun indices_of_ctr_arg parent_Tkks (U as Type (s, Us)) (accum as (Tparents_ksss, kk')) =
    71               if s = @{type_name fun} then
    65               if s = @{type_name fun} then
    72                 if exists_subtype_in mutual_Ts U then
    66                 if exists_subtype_in mutual_Ts U then
    73                   (warning "Incomplete support for recursion through functions -- \
    67                   (warning "Incomplete support for recursion through functions -- \
    74                      \the old 'primrec' will fail";
    68                      \the old 'primrec' will fail";
    75                    Tindices_of_ctr_arg parent_Tkks (range_type U) accum)
    69                    indices_of_ctr_arg parent_Tkks (range_type U) accum)
    76                 else
    70                 else
    77                   ([], accum)
    71                   ([], accum)
    78               else
    72               else
    79                 (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of
    73                 (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of
    80                   SOME kk => ([kk], accum)
    74                   SOME kk => ([kk], accum)
    81                 | NONE =>
    75                 | NONE =>
    82                   if exists_subtype_in mutual_Ts U then
    76                   if exists (exists_strict_subtype_in mutual_Ts) Us then
    83                     ([kk'], nested_Tindicessss_of parent_Tkks U kk' |>> append Tkssss)
    77                     error "Deeply nested recursion not supported"
       
    78                   else if exists (member (op =) mutual_Ts) Us then
       
    79                     ([kk'],
       
    80                      nested_Tparentss_indicessss_of parent_Tkks U kk' |>> append Tparents_ksss)
    84                   else
    81                   else
    85                     ([], accum))
    82                     ([], accum))
    86             | Tindices_of_ctr_arg _ _ accum = ([], accum);
    83             | indices_of_ctr_arg _ _ accum = ([], accum);
    87 
    84 
    88           fun Tindicesss_of_mutual_type T kk ctr_Tss =
    85           fun Tparents_indicesss_of_mutual_type T kk ctr_Tss =
    89             fold_map (fold_map (Tindices_of_ctr_arg ((T, kk) :: parent_Tkks))) ctr_Tss
    86             let val parent_Tkks' = (T, kk) :: parent_Tkks in
    90             #>> pair T;
    87               fold_map (fold_map (indices_of_ctr_arg parent_Tkks')) ctr_Tss
       
    88               #>> pair parent_Tkks'
       
    89             end;
    91 
    90 
    92           val ctrss = map #ctrs ctr_sugars;
    91           val ctrss = map #ctrs ctr_sugars;
    93           val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss;
    92           val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss;
    94         in
    93         in
    95           ([], kk + mutual_nn)
    94           ([], kk + mutual_nn)
    96           |> fold_map3 Tindicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss
    95           |> fold_map3 Tparents_indicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss
    97           |> (fn (Tkkssss, (Tkkssss', kk)) => (Tkkssss @ Tkkssss', kk))
    96           |> (fn (Tparentss_kkssss, (Tparentss_kkssss', kk)) =>
       
    97             (Tparentss_kkssss @ Tparentss_kkssss', kk))
    98         end
    98         end
    99       | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
    99       | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^
   100           " not corresponding to new-style datatype (cf. \"datatype_new\")"));
   100           " not corresponding to new-style datatype (cf. \"datatype_new\")"));
   101 
   101 
   102     fun get_indices (Var ((_, kk), _)) = [kk];
   102     fun get_indices (Var ((_, kk), _)) = [kk];
   103 
   103 
   104     val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0;
   104     val (Tparentss_kkssss, _) = nested_Tparentss_indicessss_of [] fpT1 0;
   105     val Ts = map fst Tkkssss;
   105     val Tparentss = map fst Tparentss_kkssss;
       
   106     val Ts = map (fst o hd) Tparentss;
       
   107     val kkssss = map snd Tparentss_kkssss;
   106 
   108 
   107     val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
   109     val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts;
   108     val ctrss0 = map (#ctrs o of_fp_sugar #ctr_sugars) fp_sugars0;
   110     val ctrss0 = map (#ctrs o of_fp_sugar #ctr_sugars) fp_sugars0;
   109     val ctr_Tsss0 = map (map (binder_types o fastype_of)) ctrss0;
   111     val ctr_Tsss0 = map (map (binder_types o fastype_of)) ctrss0;
   110 
   112 
   111     fun apply_comps n kk =
   113     fun apply_comps n kk =
   112       mk_partial_compN n (replicate n @{typ unit} ---> @{typ unit})
   114       mk_partial_compN n (replicate n @{typ unit} ---> @{typ unit})
   113         (Var ((Name.uu, kk), @{typ "unit => unit"}));
   115         (Var ((Name.uu, kk), @{typ "unit => unit"}));
   114 
   116 
   115     val callssss =
   117     val callssss =
   116       map2 (map2 (map2 (fn kks => fn ctr_T =>
   118       map2 (map2 (map2 (fn kks => fn ctr_T => map (apply_comps (num_binder_types ctr_T)) kks)))
   117           map (apply_comps (num_binder_types ctr_T)) kks)) o snd)
   119         kkssss ctr_Tsss0;
   118         Tkkssss ctr_Tsss0;
       
   119 
   120 
   120     val b_names = Name.variant_list [] (map base_name_of_typ Ts);
   121     val b_names = Name.variant_list [] (map base_name_of_typ Ts);
   121     val compat_b_names = map (prefix compatN) b_names;
   122     val compat_b_names = map (prefix compatN) b_names;
   122     val compat_bs = map Binding.name compat_b_names;
   123     val compat_bs = map Binding.name compat_b_names;
   123     val common_name = compatN ^ mk_common_name b_names;
   124     val common_name = compatN ^ mk_common_name b_names;
       
   125 
   124     val nn_fp = length fpTs;
   126     val nn_fp = length fpTs;
   125     val nn = length Ts;
   127     val nn = length Ts;
   126 
   128 
   127     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
   129     val ((fp_sugars, (lfp_sugar_thms, _)), lthy) =
   128       if nn > nn_fp then
   130       if nn > nn_fp then
   132 
   134 
   133     val {ctr_sugars, co_inducts = [induct], co_inductss = inductss, co_iterss,
   135     val {ctr_sugars, co_inducts = [induct], co_inductss = inductss, co_iterss,
   134       co_iter_thmsss = iter_thmsss, ...} :: _ = fp_sugars;
   136       co_iter_thmsss = iter_thmsss, ...} :: _ = fp_sugars;
   135     val inducts = map the_single inductss;
   137     val inducts = map the_single inductss;
   136 
   138 
   137     val mk_dtyp = dtyp_of_typ Ts;
   139     fun mk_dtyp [] (TFree a) = Datatype_Aux.DtTFree a
   138 
   140       | mk_dtyp [] (Type (s, Ts)) = Datatype_Aux.DtType (s, map (mk_dtyp []) Ts)
   139     fun mk_ctr_descr Ts = mk_ctr Ts #> dest_Const ##> (binder_types #> map mk_dtyp);
   141       | mk_dtyp [kk] (Type (@{type_name fun}, [T, T'])) =
   140     fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) =
   142         Datatype_Aux.DtType (@{type_name fun}, [mk_dtyp [] T, mk_dtyp [kk] T'])
   141       (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs));
   143       | mk_dtyp [kk] T = if nth Ts kk = T then Datatype_Aux.DtRec kk else mk_dtyp [] T;
   142 
   144 
   143     val descr = map3 mk_typ_descr (0 upto nn - 1) Ts ctr_sugars;
   145     fun mk_ctr_descr Ts kkss ctr0 =
       
   146       mk_ctr Ts ctr0 |> (fn Const (s, T) => (s, map2 mk_dtyp kkss (binder_types T)));
       
   147     fun mk_typ_descr kksss ((Type (T_name, Ts), kk) :: parents) ctrs0 =
       
   148       (kk, (T_name, map (mk_dtyp (map snd (take 1 parents))) Ts, map2 (mk_ctr_descr Ts) kksss ctrs0));
       
   149 
       
   150     val descr = map3 mk_typ_descr kkssss Tparentss ctrss0;
   144     val recs = map (fst o dest_Const o co_rec_of) co_iterss;
   151     val recs = map (fst o dest_Const o co_rec_of) co_iterss;
   145     val rec_thms = flat (map co_rec_of iter_thmsss);
   152     val rec_thms = flat (map co_rec_of iter_thmsss);
   146 
   153 
   147     fun mk_info ({T = Type (T_name0, _), index, ...} : fp_sugar) =
   154     fun mk_info ({T = Type (T_name0, _), index, ...} : fp_sugar) =
   148       let
   155       let