allow different functions to recurse on the same type, like in the old package
authorblanchet
Fri, 14 Feb 2014 15:03:24 +0100
changeset 55480 59cc4a8bc28a
parent 55479 ece4910c3ea0
child 55481 a8b83356e869
allow different functions to recurse on the same type, like in the old package
src/HOL/Tools/BNF/bnf_comp.ML
src/HOL/Tools/BNF/bnf_def.ML
src/HOL/Tools/BNF/bnf_fp_def_sugar.ML
src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML
src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML
src/HOL/Tools/BNF/bnf_lfp_compat.ML
src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML
src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML
--- a/src/HOL/Tools/BNF/bnf_comp.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_comp.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -449,8 +449,9 @@
     val live = live_of_bnf bnf;
     val dead = dead_of_bnf bnf;
     val nwits = nwits_of_bnf bnf;
-    fun permute xs = permute_like (op =) src dest xs;
-    fun unpermute xs = permute_like (op =) dest src xs;
+
+    fun permute xs = permute_like_unique (op =) src dest xs;
+    fun unpermute xs = permute_like_unique (op =) dest src xs;
 
     val (Ds, lthy1) = apfst (map TFree)
       (Variable.invent_types (replicate dead HOLogic.typeS) lthy);
--- a/src/HOL/Tools/BNF/bnf_def.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_def.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -352,7 +352,7 @@
     val lives = lives_of_bnf bnf;
     val deads = deads_of_bnf bnf;
   in
-    permute_like (op =) (deads @ lives) Ts (replicate (length deads) dead_x @ xs)
+    permute_like_unique (op =) (deads @ lives) Ts (replicate (length deads) dead_x @ xs)
   end;
 
 (*terms*)
@@ -541,7 +541,7 @@
     val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs;
     val flat_fs' = map_args flat_fs;
   in
-    permute_like (op aconv) flat_fs fs flat_fs'
+    permute_like_unique (op aconv) flat_fs fs flat_fs'
   end;
 
 
--- a/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_def_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -65,7 +65,7 @@
      * (typ list list * typ list list list list * term list list
         * term list list list list) list option
      * (string * term list * term list list
-        * ((term list list * term list list list) * (typ list * typ list list)) list) option)
+        * ((term list list * term list list list) * typ list) list) option)
     * Proof.context
   val mk_iter_fun_arg_types: typ list list list -> int list -> int list list -> term ->
     typ list list list list
@@ -77,7 +77,7 @@
     (string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
     (term list * thm list) * Proof.context
   val define_coiters: string list -> string * term list * term list list
-    * ((term list list * term list list list) * (typ list * typ list list)) list ->
+    * ((term list list * term list list list) * typ list) list ->
     (string -> binding) -> typ list -> typ list -> term list -> Proof.context ->
     (term list * thm list) * Proof.context
   val derive_induct_iters_thms_for_types: BNF_Def.bnf list ->
@@ -87,7 +87,7 @@
     thm list list -> local_theory -> lfp_sugar_thms
   val derive_coinduct_coiters_thms_for_types: BNF_Def.bnf list ->
     string * term list * term list list * ((term list list * term list list list)
-      * (typ list * typ list list)) list ->
+      * typ list) list ->
     thm -> thm list -> thm list -> thm list list -> BNF_Def.bnf list -> typ list -> typ list ->
     typ list -> typ list list list -> int list list -> int list list -> int list -> thm list list ->
     Ctr_Sugar.ctr_sugar list -> term list list -> thm list list -> (thm list -> thm list) ->
@@ -443,9 +443,8 @@
       let
         val fun_Ts = map get_Ts dtor_coiter_fun_Tss;
         val (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts) = mk_coiter_fun_arg_types0 ctr_Tsss Cs ns fun_Ts;
-        val pf_Tss = map3 flat_corec_preds_predsss_gettersss p_Tss q_Tssss f_Tssss;
       in
-        (q_Tssss, f_Tsss, f_Tssss, (f_sum_prod_Ts, pf_Tss))
+        (q_Tssss, f_Tsss, f_Tssss, f_sum_prod_Ts)
       end;
 
     val (r_Tssss, g_Tsss, g_Tssss, unfold_types) = mk_types un_fold_of;
@@ -537,7 +536,7 @@
   let
     val nn = length fpTs;
 
-    val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of (hd ctor_iters)));
+    val Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of (hd ctor_iters)));
 
     fun generate_iter pre (_, _, fss, xssss) ctor_iter =
       (mk_binding pre,
@@ -552,9 +551,9 @@
   let
     val nn = length fpTs;
 
-    val C_to_fpT as Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of (hd dtor_coiters)));
+    val Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of (hd dtor_coiters)));
 
-    fun generate_coiter pre ((pfss, cqfsss), (f_sum_prod_Ts, pf_Tss)) dtor_coiter =
+    fun generate_coiter pre ((pfss, cqfsss), f_sum_prod_Ts) dtor_coiter =
       (mk_binding pre,
        fold_rev (fold_rev Term.lambda) pfss (Term.list_comb (dtor_coiter,
          map4 mk_preds_getterss_join cs cpss f_sum_prod_Ts cqfsss)));
--- a/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_fp_n2m_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -147,40 +147,38 @@
     fun check_call_dead live_call call =
       if null (get_indices call) then () else incompatible_calls [live_call, call];
 
-    fun freeze_fpTs_default (T as Type (s, Ts)) =
-        (case find_index (curry (op =) T) fpTs of
-          ~1 => Type (s, map freeze_fpTs_default Ts)
-        | kk => nth Xs kk)
-      | freeze_fpTs_default T = T;
+    fun freeze_fpTs_type_based_default (T as Type (s, Ts)) =
+        (case filter (curry (op =) T o snd) (map_index I fpTs) of
+          [(kk, _)] => nth Xs kk
+        | _ => Type (s, map freeze_fpTs_type_based_default Ts))
+      | freeze_fpTs_type_based_default T = T;
 
-    fun freeze_fpTs_simple calls (T as Type (s, Ts)) =
-        (case fold (union (op =)) (map get_indices calls) [] of
-          [] => freeze_fpTs_default T
-        | [kk] => nth Xs kk
-        | _ => incompatible_calls calls)
-      | freeze_fpTs_simple _ T = T;
+    fun freeze_fpTs_mutual_call calls T =
+      (case fold (union (op =)) (map get_indices calls) [] of
+        [] => freeze_fpTs_type_based_default T
+      | [kk] => nth Xs kk
+      | _ => incompatible_calls calls);
 
     fun freeze_fpTs_map (fpT as Type (_, Ts')) (callss, (live_call :: _, dead_calls))
-        (T as Type (s, Ts)) =
+        (Type (s, Ts)) =
       if Ts' = Ts then
         nested_self_call live_call
       else
         (List.app (check_call_dead live_call) dead_calls;
-         Type (s, map2 (freeze_fpTs fpT) (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) []
-           (transpose callss)) Ts))
-    and freeze_fpTs fpT calls (T as Type (s, _)) =
+         Type (s, map2 (freeze_fpTs_call fpT)
+           (flatten_type_args_of_bnf (the (bnf_of no_defs_lthy s)) [] (transpose callss)) Ts))
+    and freeze_fpTs_call fpT calls (T as Type (s, _)) =
         (case map_partition (try (snd o dest_map no_defs_lthy s)) calls of
           ([], _) =>
           (case map_partition (try (snd o dest_abs_or_applied_map no_defs_lthy s)) calls of
-            ([], _) => freeze_fpTs_simple calls T
+            ([], _) => freeze_fpTs_mutual_call calls T
           | callsp => freeze_fpTs_map fpT callsp T)
         | callsp => freeze_fpTs_map fpT callsp T)
-      | freeze_fpTs _ _ T = T;
+      | freeze_fpTs_call _ _ T = T;
 
     val ctr_Tsss = map (map binder_types) ctr_Tss;
-    val ctrXs_Tsss = map3 (map2 o map2 o freeze_fpTs) fpTs callssss ctr_Tsss;
+    val ctrXs_Tsss = map3 (map2 o map2 o freeze_fpTs_call) fpTs callssss ctr_Tsss;
     val ctrXs_sum_prod_Ts = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctrXs_Tsss;
-    val ctr_Ts = map (body_type o hd) ctr_Tss;
 
     val ns = map length ctr_Tsss;
     val kss = map (fn n => 1 upto n) ns;
@@ -270,6 +268,8 @@
 
 fun retypargs tyargs (Type (s, _)) = Type (s, tyargs);
 
+fun exists_strict_subtype_in Ts T = exists_subtype_in (filter_out (curry (op =) T) Ts) T;
+
 fun fold_subtype_pairs f (T as Type (s, Ts), U as Type (s', Us)) =
     f (T, U) #> (if s = s' then fold (fold_subtype_pairs f) (Ts ~~ Us) else I)
   | fold_subtype_pairs f TU = f TU;
@@ -279,7 +279,6 @@
     val qsoty = quote o Syntax.string_of_typ lthy;
     val qsotys = space_implode " or " o map qsoty;
 
-    fun duplicate_datatype T = error (qsoty T ^ " is not mutually recursive with itself");
     fun not_co_datatype0 T = error (qsoty T ^ " is not a " ^ co_prefix fp ^ "datatype");
     fun not_co_datatype (T as Type (s, _)) =
         if fp = Least_FP andalso
@@ -290,11 +289,10 @@
       | not_co_datatype T = not_co_datatype0 T;
     fun not_mutually_nested_rec Ts1 Ts2 =
       error (qsotys Ts1 ^ " is neither mutually recursive with " ^ qsotys Ts2 ^
-        " nor nested recursive via " ^ qsotys Ts2);
+        " nor nested recursive through " ^
+        (if Ts1 = Ts2 andalso can the_single Ts1 then "itself" else qsotys Ts2));
 
-    val _ = (case Library.duplicates (op =) actual_Ts of [] => () | T :: _ => duplicate_datatype T);
-
-    val perm_actual_Ts =
+    val sorted_actual_Ts =
       sort (prod_ord int_ord Term_Ord.typ_ord o pairself (`Term.size_of_typ)) actual_Ts;
 
     fun the_ctrs_of (Type (s, Ts)) = map (mk_ctr Ts) (#ctrs (the (ctr_sugar_of lthy s)));
@@ -323,7 +321,7 @@
           val {fp_res = {Ts = mutual_Ts0, ...}, ...} = the_fp_sugar_of T;
           val mutual_Ts = map (retypargs tyargs) mutual_Ts0;
 
-          val _ = seen = [] orelse exists (exists_subtype_in seen) mutual_Ts orelse
+          val _ = seen = [] orelse exists (exists_strict_subtype_in seen) mutual_Ts orelse
             not_mutually_nested_rec mutual_Ts seen;
 
           fun fresh_tyargs () =
@@ -354,17 +352,18 @@
               fresh_tyargs ();
 
           val gen_mutual_Ts = map (retypargs gen_tyargs) mutual_Ts0;
-          val Ts' = filter_out (member (op =) mutual_Ts) Ts;
+          val other_mutual_Ts = remove1 (op =) T mutual_Ts;
+          val Ts' = fold (remove1 (op =)) other_mutual_Ts Ts;
         in
           gather_types lthy' rho' (num_groups + 1) (seen @ mutual_Ts) (gen_seen' @ gen_mutual_Ts)
             Ts'
         end
       | gather_types _ _ _ _ _ (T :: _) = not_co_datatype T;
 
-    val (num_groups, perm_Ts, perm_gen_Ts) = gather_types lthy [] 0 [] [] perm_actual_Ts;
+    val (num_groups, perm_Ts, perm_gen_Ts) = gather_types lthy [] 0 [] [] sorted_actual_Ts;
     val perm_frozen_gen_Ts = map Logic.unvarifyT_global perm_gen_Ts;
 
-    val missing_Ts = perm_Ts |> subtract (op =) actual_Ts;
+    val missing_Ts = subtract (op =) actual_Ts perm_Ts;
     val Ts = actual_Ts @ missing_Ts;
 
     val nn = length Ts;
--- a/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_gfp_rec_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -407,8 +407,8 @@
     val fun_arg_hs =
       flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss);
 
-    fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
-    fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
+    fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs;
+    fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;
 
     val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms;
 
--- a/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_compat.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -67,8 +67,7 @@
           val mutual_kks = kk upto kk + mutual_nn - 1;
           val mutual_Tkks = mutual_Ts ~~ mutual_kks;
 
-          fun Tindices_of_ctr_arg (parent_Tkks as (_, parent_kk) :: _) (U as Type (s, _))
-                (accum as (Tkssss, kk')) =
+          fun Tindices_of_ctr_arg parent_Tkks (U as Type (s, _)) (accum as (Tkssss, kk')) =
               if s = @{type_name fun} then
                 if exists_subtype_in mutual_Ts U then
                   (warning "Incomplete support for recursion through functions -- \
--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -162,8 +162,8 @@
     val perm_fun_arg_Tssss =
       mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);
 
-    fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
-    fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
+    fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs;
+    fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;
 
     val induct_thms = unpermute0 (conj_dests nn induct_thm);
 
--- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -195,7 +195,7 @@
 val code_nitpicksimp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs;
 val code_nitpicksimp_simp_attrs = code_nitpicksimp_attrs @ simp_attrs;
 
-fun unflat_lookup eq xs ys = map (fn xs' => permute_like eq xs xs' ys);
+fun unflat_lookup eq xs ys = map (fn xs' => permute_like_unique eq xs xs' ys);
 
 fun mk_half_pairss' _ ([], []) = []
   | mk_half_pairss' indent (x :: xs, _ :: ys) =
--- a/src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML	Fri Feb 14 15:03:23 2014 +0100
+++ b/src/HOL/Tools/Ctr_Sugar/ctr_sugar_util.ML	Fri Feb 14 15:03:24 2014 +0100
@@ -19,7 +19,8 @@
   val transpose: 'a list list -> 'a list list
   val pad_list: 'a -> int -> 'a list -> 'a list
   val splice: 'a list -> 'a list -> 'a list
-  val permute_like: ('a * 'b -> bool) -> 'a list -> 'b list -> 'c list -> 'c list
+  val permute_like_unique: ('a * 'b -> bool) -> 'a list -> 'b list -> 'c list -> 'c list
+  val permute_like: ('a * 'a -> bool) -> 'a list -> 'a list -> 'b list -> 'b list
 
   val mk_names: int -> string -> string list
   val mk_fresh_names: Proof.context -> int -> string -> string list * Proof.context
@@ -129,7 +130,18 @@
 
 fun splice xs ys = flat (map2 (fn x => fn y => [x, y]) xs ys);
 
-fun permute_like eq xs xs' ys = map (nth ys o (fn y => find_index (fn x => eq (x, y)) xs)) xs';
+fun permute_like_unique eq xs xs' ys =
+  map (nth ys o (fn y => find_index (fn x => eq (x, y)) xs)) xs';
+
+fun fresh eq x names =
+  (case AList.lookup eq names x of
+    NONE => ((x, 0), (x, 0) :: names)
+  | SOME n => ((x, n + 1), AList.update eq (x, n + 1) names));
+
+fun deambiguate eq xs = fst (fold_map (fresh eq) xs []);
+
+fun permute_like eq xs xs' =
+  permute_like_unique (eq_pair eq (op =)) (deambiguate eq xs) (deambiguate eq xs');
 
 fun mk_names n x = if n = 1 then [x] else map (fn i => x ^ string_of_int i) (1 upto n);
 fun mk_fresh_names ctxt = (fn xs => Variable.variant_fixes xs ctxt) oo mk_names;