allow different functions to recurse on the same type, like in the old package
authorblanchet
Fri Feb 14 15:03:24 2014 +0100 (2014-02-14)
changeset 5548059cc4a8bc28a
parent 55479 ece4910c3ea0
child 55481 a8b83356e869
allow different functions to recurse on the same type, like in the old package
src/HOL/Tools/BNF/bnf_comp.ML
src/HOL/Tools/BNF/bnf_def.ML
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_compat.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML
src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML
     1.1 --- a/src/HOL/Tools/BNF/bnf_comp.ML	Fri Feb 14 15:03:23 2014 +0100
     1.2 +++ b/src/HOL/Tools/BNF/bnf_comp.ML	Fri Feb 14 15:03:24 2014 +0100
     1.3 @@ -449,8 +449,9 @@
     1.4      val live = live_of_bnf bnf;
     1.5      val dead = dead_of_bnf bnf;
     1.6      val nwits = nwits_of_bnf bnf;
     1.7 -    fun permute xs = permute_like (op =) src dest xs;
     1.8 -    fun unpermute xs = permute_like (op =) dest src xs;
     1.9 +
    1.10 +    fun permute xs = permute_like_unique (op =) src dest xs;
    1.11 +    fun unpermute xs = permute_like_unique (op =) dest src xs;
    1.12  
    1.13      val (Ds, lthy1) = apfst (map TFree)
    1.14        (Variable.invent_types (replicate dead HOLogic.typeS) lthy);
     2.1 --- a/src/HOL/Tools/BNF/bnf_def.ML	Fri Feb 14 15:03:23 2014 +0100
     2.2 +++ b/src/HOL/Tools/BNF/bnf_def.ML	Fri Feb 14 15:03:24 2014 +0100
     2.3 @@ -352,7 +352,7 @@
     2.4      val lives = lives_of_bnf bnf;
     2.5      val deads = deads_of_bnf bnf;
     2.6    in
     2.7 -    permute_like (op =) (deads @ lives) Ts (replicate (length deads) dead_x @ xs)
     2.8 +    permute_like_unique (op =) (deads @ lives) Ts (replicate (length deads) dead_x @ xs)
     2.9    end;
    2.10  
    2.11  (*terms*)
    2.12 @@ -541,7 +541,7 @@
    2.13      val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
    2.14      val flat_fs' = map_args flat_fs;
    2.15    in
    2.16 -    permute_like (op aconv) flat_fs fs flat_fs'
    2.17 +    permute_like_unique (op aconv) flat_fs fs flat_fs'
    2.18    end;
    2.19  
    2.20  
     3.1 --- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
     3.2 +++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
     3.3 @@ -65,7 +65,7 @@
     3.4       * (typ list list * typ list list list list * term list list
     3.5          * term list list list list) list option
     3.6       * (string * term list * term list list
     3.7 -        * ((term list list * term list list list) * (typ list * typ list list)) list) option)
     3.8 +        * ((term list list * term list list list) * typ list) list) option)
     3.9      * Proof.context
    3.10    val mk_iter_fun_arg_types: typ list list list -> int list -> int list list -> term ->
    3.11      typ list list list list
    3.12 @@ -77,7 +77,7 @@
    3.13      (string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
    3.14      (term list * thm list) * Proof.context
    3.15    val define_coiters: string list -> string * term list * term list list
    3.16 -    * ((term list list * term list list list) * (typ list * typ list list)) list ->
    3.17 +    * ((term list list * term list list list) * typ list) list ->
    3.18      (string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
    3.19      (term list * thm list) * Proof.context
    3.20    val derive_induct_iters_thms_for_types: BNF_Def.bnf list ->
    3.21 @@ -87,7 +87,7 @@
    3.22      thm list list -> local_theory -> lfp_sugar_thms
    3.23    val derive_coinduct_coiters_thms_for_types: BNF_Def.bnf list ->
    3.24      string * term list * term list list * ((term list list * term list list list)
    3.25 -      * (typ list * typ list list)) list ->
    3.26 +      * typ list) list ->
    3.27      thm -> thm list -> thm list -> thm list list -> BNF_Def.bnf list -> typ list -> typ list ->
    3.28      typ list -> typ list list list -> int list list -> int list list -> int list -> thm list list ->
    3.29      Ctr_Sugar.ctr_sugar list -> term list list -> thm list list -> (thm list -> thm list) ->
    3.30 @@ -443,9 +443,8 @@
    3.31        let
    3.32          val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
    3.33          val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts;
    3.34 -        val pf_Tss = map3 flat_corec_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
    3.35        in
    3.36 -        (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss))
    3.37 +        (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts)
    3.38        end;
    3.39  
    3.40      val (r_Tssss, g_Tsss, g_Tssss, unfold_types) = mk_types un_fold_of;
    3.41 @@ -537,7 +536,7 @@
    3.42    let
    3.43      val nn = length fpTs;
    3.44  
    3.45 -    val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of (hd ctor_iters)));
    3.46 +    val Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of (hd ctor_iters)));
    3.47  
    3.48      fun generate_iter pre (_, _, fss, xssss) ctor_iter =
    3.49        (mk_binding pre,
    3.50 @@ -552,9 +551,9 @@
    3.51    let
    3.52      val nn = length fpTs;
    3.53  
    3.54 -    val C_to_fpT as Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of (hd dtor_coiters)));
    3.55 +    val Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of (hd dtor_coiters)));
    3.56  
    3.57 -    fun generate_coiter pre ((pfss, cqfsss), (f_sum_prod_Ts, pf_Tss)) dtor_coiter =
    3.58 +    fun generate_coiter pre ((pfss, cqfsss), f_sum_prod_Ts) dtor_coiter =
    3.59        (mk_binding pre,
    3.60         fold_rev (fold_rev Term.lambda) pfss (Term.list_comb (dtor_coiter,
    3.61           map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss)));
     4.1 --- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
     4.2 +++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
     4.3 @@ -147,40 +147,38 @@
     4.4      fun check_call_dead live_call call =
     4.5        if null (get_indices call) then () else incompatible_calls [live_call, call];
     4.6  
     4.7 -    fun freeze_fpTs_default (T as Type (s, Ts)) =
     4.8 -        (case find_index (curry (op =) T) fpTs of
     4.9 -          ~1 => Type (s, map freeze_fpTs_default Ts)
    4.10 -        | kk => nth Xs kk)
    4.11 -      | freeze_fpTs_default T = T;
    4.12 +    fun freeze_fpTs_type_based_default (T as Type (s, Ts)) =
    4.13 +        (case filter (curry (op =) T o snd) (map_index I fpTs) of
    4.14 +          [(kk, _)] => nth Xs kk
    4.15 +        | _ => Type (s, map freeze_fpTs_type_based_default Ts))
    4.16 +      | freeze_fpTs_type_based_default T = T;
    4.17  
    4.18 -    fun freeze_fpTs_simple calls (T as Type (s, Ts)) =
    4.19 -        (case fold (union (op =)) (map get_indices calls) [] of
    4.20 -          [] => freeze_fpTs_default T
    4.21 -        | [kk] => nth Xs kk
    4.22 -        | _ => incompatible_calls calls)
    4.23 -      | freeze_fpTs_simple _ T = T;
    4.24 +    fun freeze_fpTs_mutual_call calls T =
    4.25 +      (case fold (union (op =)) (map get_indices calls) [] of
    4.26 +        [] => freeze_fpTs_type_based_default T
    4.27 +      | [kk] => nth Xs kk
    4.28 +      | _ => incompatible_calls calls);
    4.29  
    4.30      fun freeze_fpTs_map (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls))
    4.31 -        (T as Type (s, Ts)) =
    4.32 +        (Type (s, Ts)) =
    4.33        if Ts' = Ts then
    4.34          nested_self_call live_call
    4.35        else
    4.36          (List.app (check_call_dead live_call) dead_calls;
    4.37 -         Type (s, map2 (freeze_fpTs fpT) (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
    4.38 -           (transpose callss)) Ts))
    4.39 -    and freeze_fpTs fpT calls (T as Type (s, _)) =
    4.40 +         Type (s, map2 (freeze_fpTs_call fpT)
    4.41 +           (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) [] (transpose callss)) Ts))
    4.42 +    and freeze_fpTs_call fpT calls (T as Type (s, _)) =
    4.43          (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
    4.44            ([], _) =>
    4.45            (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
    4.46 -            ([], _) => freeze_fpTs_simple calls T
    4.47 +            ([], _) => freeze_fpTs_mutual_call calls T
    4.48            | callsp => freeze_fpTs_map fpT callsp T)
    4.49          | callsp => freeze_fpTs_map fpT callsp T)
    4.50 -      | freeze_fpTs _ _ T = T;
    4.51 +      | freeze_fpTs_call _ _ T = T;
    4.52  
    4.53      val ctr_Tsss = map (map binder_types) ctr_Tss;
    4.54 -    val ctrXs_Tsss = map3 (map2 o map2 o freeze_fpTs) fpTs callssss ctr_Tsss;
    4.55 +    val ctrXs_Tsss = map3 (map2 o map2 o freeze_fpTs_call) fpTs callssss ctr_Tsss;
    4.56      val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
    4.57 -    val ctr_Ts = map (body_type o hd) ctr_Tss;
    4.58  
    4.59      val ns = map length ctr_Tsss;
    4.60      val kss = map (fn n => 1 upto n) ns;
    4.61 @@ -270,6 +268,8 @@
    4.62  
    4.63  fun retypargs tyargs (Type (s, _)) = Type (s, tyargs);
    4.64  
    4.65 +fun exists_strict_subtype_in Ts T = exists_subtype_in (filter_out (curry (op =) T) Ts) T;
    4.66 +
    4.67  fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
    4.68      f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
    4.69    | fold_subtype_pairs f TU = f TU;
    4.70 @@ -279,7 +279,6 @@
    4.71      val qsoty = quote o Syntax.string_of_typ lthy;
    4.72      val qsotys = space_implode " or " o map qsoty;
    4.73  
    4.74 -    fun duplicate_datatype T = error (qsoty T ^ " is not mutually recursive with itself");
    4.75      fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
    4.76      fun not_co_datatype (T as Type (s, _)) =
    4.77          if fp = Least_FP andalso
    4.78 @@ -290,11 +289,10 @@
    4.79        | not_co_datatype T = not_co_datatype0 T;
    4.80      fun not_mutually_nested_rec Ts1 Ts2 =
    4.81        error (qsotys Ts1 ^ " is neither mutually recursive with " ^ qsotys Ts2 ^
    4.82 -        " nor nested recursive via " ^ qsotys Ts2);
    4.83 +        " nor nested recursive through " ^
    4.84 +        (if Ts1 = Ts2 andalso can the_single Ts1 then "itself" else qsotys Ts2));
    4.85  
    4.86 -    val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
    4.87 -
    4.88 -    val perm_actual_Ts =
    4.89 +    val sorted_actual_Ts =
    4.90        sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
    4.91  
    4.92      fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
    4.93 @@ -323,7 +321,7 @@
    4.94            val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
    4.95            val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
    4.96  
    4.97 -          val _ = seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
    4.98 +          val _ = seen = [] orelse exists (exists_strict_subtype_in seen) mutual_Ts orelse
    4.99              not_mutually_nested_rec mutual_Ts seen;
   4.100  
   4.101            fun fresh_tyargs () =
   4.102 @@ -354,17 +352,18 @@
   4.103                fresh_tyargs ();
   4.104  
   4.105            val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0;
   4.106 -          val Ts' = filter_out (member (op =) mutual_Ts) Ts;
   4.107 +          val other_mutual_Ts = remove1 (op =) T mutual_Ts;
   4.108 +          val Ts' = fold (remove1 (op =)) other_mutual_Ts Ts;
   4.109          in
   4.110            gather_types lthy' rho' (num_groups + 1) (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts)
   4.111              Ts'
   4.112          end
   4.113        | gather_types _ _ _ _ _ (T :: _) = not_co_datatype T;
   4.114  
   4.115 -    val (num_groups, perm_Ts, perm_gen_Ts) = gather_types lthy [] 0 [] [] perm_actual_Ts;
   4.116 +    val (num_groups, perm_Ts, perm_gen_Ts) = gather_types lthy [] 0 [] [] sorted_actual_Ts;
   4.117      val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts;
   4.118  
   4.119 -    val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
   4.120 +    val missing_Ts = subtract (op =) actual_Ts perm_Ts;
   4.121      val Ts = actual_Ts @ missing_Ts;
   4.122  
   4.123      val nn = length Ts;
     5.1 --- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
     5.2 +++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
     5.3 @@ -407,8 +407,8 @@
     5.4      val fun_arg_hs =
     5.5        flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss);
     5.6  
     5.7 -    fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
     5.8 -    fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
     5.9 +    fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs;
    5.10 +    fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;
    5.11  
    5.12      val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms;
    5.13  
     6.1 --- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 15:03:23 2014 +0100
     6.2 +++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 15:03:24 2014 +0100
     6.3 @@ -67,8 +67,7 @@
     6.4            val mutual_kks = kk upto kk + mutual_nn - 1;
     6.5            val mutual_Tkks = mutual_Ts ~~ mutual_kks;
     6.6  
     6.7 -          fun Tindices_of_ctr_arg (parent_Tkks as (_, parent_kk) :: _) (U as Type (s, _))
     6.8 -                (accum as (Tkssss, kk')) =
     6.9 +          fun Tindices_of_ctr_arg parent_Tkks (U as Type (s, _)) (accum as (Tkssss, kk')) =
    6.10                if s = @{type_name fun} then
    6.11                  if exists_subtype_in mutual_Ts U then
    6.12                    (warning "Incomplete support for recursion through functions -- \
     7.1 --- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
     7.2 +++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
     7.3 @@ -162,8 +162,8 @@
     7.4      val perm_fun_arg_Tssss =
     7.5        mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);
     7.6  
     7.7 -    fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
     7.8 -    fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
     7.9 +    fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs;
    7.10 +    fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;
    7.11  
    7.12      val induct_thms = unpermute0 (conj_dests nn induct_thm);
    7.13  
     8.1 --- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
     8.2 +++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
     8.3 @@ -195,7 +195,7 @@
     8.4  val code_nitpicksimp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs;
     8.5  val code_nitpicksimp_simp_attrs = code_nitpicksimp_attrs @ simp_attrs;
     8.6  
     8.7 -fun unflat_lookup eq xs ys = map (fn xs' => permute_like eq xs xs' ys);
     8.8 +fun unflat_lookup eq xs ys = map (fn xs' => permute_like_unique eq xs xs' ys);
     8.9  
    8.10  fun mk_half_pairss' _ ([], []) = []
    8.11    | mk_half_pairss' indent (x :: xs, _ :: ys) =
     9.1 --- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML	Fri Feb 14 15:03:23 2014 +0100
     9.2 +++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML	Fri Feb 14 15:03:24 2014 +0100
     9.3 @@ -19,7 +19,8 @@
     9.4    val transpose: 'a list list -> 'a list list
     9.5    val pad_list: 'a -> int -> 'a list -> 'a list
     9.6    val splice: 'a list -> 'a list -> 'a list
     9.7 -  val permute_like: ('a * 'b -> bool) -> 'a list -> 'b list -> 'c list -> 'c list
     9.8 +  val permute_like_unique: ('a * 'b -> bool) -> 'a list -> 'b list -> 'c list -> 'c list
     9.9 +  val permute_like: ('a * 'a -> bool) -> 'a list -> 'a list -> 'b list -> 'b list
    9.10  
    9.11    val mk_names: int -> string -> string list
    9.12    val mk_fresh_names: Proof.context -> int -> string -> string list * Proof.context
    9.13 @@ -129,7 +130,18 @@
    9.14  
    9.15  fun splice xs ys = flat (map2 (fn x => fn y => [x, y]) xs ys);
    9.16  
    9.17 -fun permute_like eq xs xs' ys = map (nth ys o (fn y => find_index (fn x => eq (x, y)) xs)) xs';
    9.18 +fun permute_like_unique eq xs xs' ys =
    9.19 +  map (nth ys o (fn y => find_index (fn x => eq (x, y)) xs)) xs';
    9.20 +
    9.21 +fun fresh eq x names =
    9.22 +  (case AList.lookup eq names x of
    9.23 +    NONE => ((x, 0), (x, 0) :: names)
    9.24 +  | SOME n => ((x, n + 1), AList.update eq (x, n + 1) names));
    9.25 +
    9.26 +fun deambiguate eq xs = fst (fold_map (fresh eq) xs []);
    9.27 +
    9.28 +fun permute_like eq xs xs' =
    9.29 +  permute_like_unique (eq_pair eq (op =)) (deambiguate eq xs) (deambiguate eq xs');
    9.30  
    9.31  fun mk_names n x = if n = 1 then [x] else map (fn i => x ^ string_of_int i) (1 upto n);
    9.32  fun mk_fresh_names ctxt = (fn xs => Variable.variant_fixes xs ctxt) oo mk_names;