diff -r bdfb607543f4 -r 8609527278f2 src/HOL/Tools/BNF/bnf_lfp_compat.ML --- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 17:18:28 2014 +0100 +++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML Fri Feb 14 18:42:43 2014 +0100 @@ -19,12 +19,6 @@ open BNF_FP_Def_Sugar open BNF_FP_N2M_Sugar -fun dtyp_of_typ _ (TFree a) = Datatype_Aux.DtTFree a - | dtyp_of_typ recTs (T as Type (s, Ts)) = - (case find_index (curry (op =) T) recTs of - ~1 => Datatype_Aux.DtType (s, map (dtyp_of_typ recTs) Ts) - | kk => Datatype_Aux.DtRec kk); - val compatN = "compat_"; val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: @{attributes [nitpick_simp, simp]}; @@ -56,7 +50,7 @@ val As = map2 (resort_tfree o Type.sort_of_atyp) var_As unsorted_As; val fpTs as fpT1 :: _ = map (fn s => Type (s, As)) fpT_names'; - fun nested_Tindicessss_of parent_Tkks (T as Type (s, _)) kk = + fun nested_Tparentss_indicessss_of parent_Tkks (T as Type (s, _)) kk = (case try lfp_sugar_of s of SOME ({T = T0, fp_res = {Ts = mutual_Ts0, ...}, ctr_sugars, ...}) => let @@ -67,42 +61,50 @@ val mutual_kks = kk upto kk + mutual_nn - 1; val mutual_Tkks = mutual_Ts ~~ mutual_kks; - fun Tindices_of_ctr_arg parent_Tkks (U as Type (s, _)) (accum as (Tkssss, kk')) = + fun indices_of_ctr_arg parent_Tkks (U as Type (s, Us)) (accum as (Tparents_ksss, kk')) = if s = @{type_name fun} then if exists_subtype_in mutual_Ts U then (warning "Incomplete support for recursion through functions -- \ \the old 'primrec' will fail"; - Tindices_of_ctr_arg parent_Tkks (range_type U) accum) + indices_of_ctr_arg parent_Tkks (range_type U) accum) else ([], accum) else (case AList.lookup (op =) (parent_Tkks @ mutual_Tkks) U of SOME kk => ([kk], accum) | NONE => - if exists_subtype_in mutual_Ts U then - ([kk'], nested_Tindicessss_of parent_Tkks U kk' |>> append Tkssss) + if exists (exists_strict_subtype_in mutual_Ts) Us then + error "Deeply nested recursion not supported" + else if exists (member (op =) mutual_Ts) Us then + ([kk'], + nested_Tparentss_indicessss_of parent_Tkks U kk' |>> append Tparents_ksss) else ([], accum)) - | Tindices_of_ctr_arg _ _ accum = ([], accum); + | indices_of_ctr_arg _ _ accum = ([], accum); - fun Tindicesss_of_mutual_type T kk ctr_Tss = - fold_map (fold_map (Tindices_of_ctr_arg ((T, kk) :: parent_Tkks))) ctr_Tss - #>> pair T; + fun Tparents_indicesss_of_mutual_type T kk ctr_Tss = + let val parent_Tkks' = (T, kk) :: parent_Tkks in + fold_map (fold_map (indices_of_ctr_arg parent_Tkks')) ctr_Tss + #>> pair parent_Tkks' + end; val ctrss = map #ctrs ctr_sugars; val ctr_Tsss = map (map (binder_types o substT o fastype_of)) ctrss; in ([], kk + mutual_nn) - |> fold_map3 Tindicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss - |> (fn (Tkkssss, (Tkkssss', kk)) => (Tkkssss @ Tkkssss', kk)) + |> fold_map3 Tparents_indicesss_of_mutual_type mutual_Ts mutual_kks ctr_Tsss + |> (fn (Tparentss_kkssss, (Tparentss_kkssss', kk)) => + (Tparentss_kkssss @ Tparentss_kkssss', kk)) end | NONE => error ("Unsupported recursion via type constructor " ^ quote s ^ " not corresponding to new-style datatype (cf. \"datatype_new\")")); fun get_indices (Var ((_, kk), _)) = [kk]; - val (Tkkssss, _) = nested_Tindicessss_of [] fpT1 0; - val Ts = map fst Tkkssss; + val (Tparentss_kkssss, _) = nested_Tparentss_indicessss_of [] fpT1 0; + val Tparentss = map fst Tparentss_kkssss; + val Ts = map (fst o hd) Tparentss; + val kkssss = map snd Tparentss_kkssss; val fp_sugars0 = map (lfp_sugar_of o fst o dest_Type) Ts; val ctrss0 = map (#ctrs o of_fp_sugar #ctr_sugars) fp_sugars0; @@ -113,14 +115,14 @@ (Var ((Name.uu, kk), @{typ "unit => unit"})); val callssss = - map2 (map2 (map2 (fn kks => fn ctr_T => - map (apply_comps (num_binder_types ctr_T)) kks)) o snd) - Tkkssss ctr_Tsss0; + map2 (map2 (map2 (fn kks => fn ctr_T => map (apply_comps (num_binder_types ctr_T)) kks))) + kkssss ctr_Tsss0; val b_names = Name.variant_list [] (map base_name_of_typ Ts); val compat_b_names = map (prefix compatN) b_names; val compat_bs = map Binding.name compat_b_names; val common_name = compatN ^ mk_common_name b_names; + val nn_fp = length fpTs; val nn = length Ts; @@ -134,13 +136,18 @@ co_iter_thmsss = iter_thmsss, ...} :: _ = fp_sugars; val inducts = map the_single inductss; - val mk_dtyp = dtyp_of_typ Ts; + fun mk_dtyp [] (TFree a) = Datatype_Aux.DtTFree a + | mk_dtyp [] (Type (s, Ts)) = Datatype_Aux.DtType (s, map (mk_dtyp []) Ts) + | mk_dtyp [kk] (Type (@{type_name fun}, [T, T'])) = + Datatype_Aux.DtType (@{type_name fun}, [mk_dtyp [] T, mk_dtyp [kk] T']) + | mk_dtyp [kk] T = if nth Ts kk = T then Datatype_Aux.DtRec kk else mk_dtyp [] T; - fun mk_ctr_descr Ts = mk_ctr Ts #> dest_Const ##> (binder_types #> map mk_dtyp); - fun mk_typ_descr index (Type (T_name, Ts)) ({ctrs, ...} : ctr_sugar) = - (index, (T_name, map mk_dtyp Ts, map (mk_ctr_descr Ts) ctrs)); + fun mk_ctr_descr Ts kkss ctr0 = + mk_ctr Ts ctr0 |> (fn Const (s, T) => (s, map2 mk_dtyp kkss (binder_types T))); + fun mk_typ_descr kksss ((Type (T_name, Ts), kk) :: parents) ctrs0 = + (kk, (T_name, map (mk_dtyp (map snd (take 1 parents))) Ts, map2 (mk_ctr_descr Ts) kksss ctrs0)); - val descr = map3 mk_typ_descr (0 upto nn - 1) Ts ctr_sugars; + val descr = map3 mk_typ_descr kkssss Tparentss ctrss0; val recs = map (fst o dest_Const o co_rec_of) co_iterss; val rec_thms = flat (map co_rec_of iter_thmsss);