Added support for parametric datatypes.
authorberghofe
Mon, 10 Nov 2008 17:38:23 +0100
changeset 28731 c60ac7923a06
parent 28730 71c946ce7eb9
child 28732 99492b224b7b
Added support for parametric datatypes.
src/HOL/Nominal/nominal_package.ML
--- a/src/HOL/Nominal/nominal_package.ML	Mon Nov 10 17:37:25 2008 +0100
+++ b/src/HOL/Nominal/nominal_package.ML	Mon Nov 10 17:38:23 2008 +0100
@@ -134,10 +134,8 @@
         val (bT as Type (b, []), _) = dest_permT U
       in if aT mem permTs_of u andalso aT <> bT then
           let
-            val a' = Sign.base_name a;
-            val b' = Sign.base_name b;
-            val cp = PureThy.get_thm thy ("cp_" ^ a' ^ "_" ^ b' ^ "_inst");
-            val dj = PureThy.get_thm thy ("dj_" ^ b' ^ "_" ^ a');
+            val cp = cp_inst_of thy a b;
+            val dj = dj_thm_of thy b a;
             val dj_cp' = [cp, dj] MRS dj_cp;
             val cert = SOME o cterm_of thy
           in
@@ -203,12 +201,6 @@
         (tname, length tvs, mx)) dts);
 
     val atoms = atoms_of thy;
-    val classes = map (NameSpace.map_base (fn s => "pt_" ^ s)) atoms;
-    val cp_classes = List.concat (map (fn atom1 => map (fn atom2 =>
-      Sign.intern_class thy ("cp_" ^ Sign.base_name atom1 ^ "_" ^
-        Sign.base_name atom2)) atoms) atoms);
-    fun augment_sort S = S union classes;
-    val augment_sort_typ = map_type_tfree (fn (s, S) => TFree (s, augment_sort S));
 
     fun prep_constr ((constrs, sorts), (cname, cargs, mx)) =
       let val (cargs', sorts') = Library.foldl (prep_typ tmp_thy) (([], sorts), cargs)
@@ -219,9 +211,16 @@
       in (dts @ [(tvs, tname, mx, constrs')], sorts') end
 
     val (dts', sorts) = Library.foldl prep_dt_spec (([], []), dts);
-    val sorts' = map (apsnd augment_sort) sorts;
     val tyvars = map #1 dts';
 
+    fun inter_sort thy S S' = Type.inter_sort (Sign.tsig_of thy) (S, S');
+    fun augment_sort_typ thy S =
+      let val S = Sign.certify_sort thy S
+      in map_type_tfree (fn (s, S') => TFree (s,
+        if member (op = o apsnd fst) sorts s then inter_sort thy S S' else S'))
+      end;
+    fun augment_sort thy S = map_types (augment_sort_typ thy S);
+
     val types_syntax = map (fn (tvs, tname, mx, constrs) => (tname, mx)) dts';
     val constr_syntax = map (fn (tvs, tname, mx, constrs) =>
       map (fn (cname, cargs, mx) => (cname, mx)) constrs) dts';
@@ -238,7 +237,7 @@
 
     val dts'' = map (fn (tvs, tname, mx, constrs) => (tvs, tname ^ "_Rep", NoSyn,
       map (fn (cname, cargs, mx) => (cname ^ "_Rep",
-        map (augment_sort_typ o replace_types) cargs, NoSyn)) constrs)) dts';
+        map replace_types cargs, NoSyn)) constrs)) dts';
 
     val new_type_names' = map (fn n => n ^ "_Rep") new_type_names;
     val full_new_type_names' = map (Sign.full_name thy) new_type_names';
@@ -248,7 +247,7 @@
 
     val SOME {descr, ...} = Symtab.lookup
       (DatatypePackage.get_datatypes thy1) (hd full_new_type_names');
-    fun nth_dtyp i = typ_of_dtyp descr sorts' (DtRec i);
+    fun nth_dtyp i = typ_of_dtyp descr sorts (DtRec i);
 
     val big_name = space_implode "_" new_type_names;
 
@@ -271,7 +270,7 @@
       let val T = nth_dtyp i
       in map (fn (cname, dts) =>
         let
-          val Ts = map (typ_of_dtyp descr sorts') dts;
+          val Ts = map (typ_of_dtyp descr sorts) dts;
           val names = Name.variant_list ["pi"] (DatatypeProp.make_tnames Ts);
           val args = map Free (names ~~ Ts);
           val c = Const (cname, Ts ---> T);
@@ -336,13 +335,14 @@
       let val permT = mk_permT (Type (a, []))
       in map standard (List.take (split_conj_thm
         (Goal.prove_global thy2 [] []
-          (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
-            (map (fn ((s, T), x) => HOLogic.mk_eq
-                (Const (s, permT --> T --> T) $
-                   Const ("List.list.Nil", permT) $ Free (x, T),
-                 Free (x, T)))
-             (perm_names ~~
-              map body_type perm_types ~~ perm_indnames))))
+          (augment_sort thy2 [pt_class_of thy2 a]
+            (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
+              (map (fn ((s, T), x) => HOLogic.mk_eq
+                  (Const (s, permT --> T --> T) $
+                     Const ("List.list.Nil", permT) $ Free (x, T),
+                   Free (x, T)))
+               (perm_names ~~
+                map body_type perm_types ~~ perm_indnames)))))
           (fn _ => EVERY [indtac induction perm_indnames 1,
             ALLGOALS (asm_full_simp_tac (simpset_of thy2))])),
         length new_type_names))
@@ -362,11 +362,12 @@
         val permT = mk_permT (Type (a, []));
         val pi1 = Free ("pi1", permT);
         val pi2 = Free ("pi2", permT);
-        val pt_inst = PureThy.get_thm thy2 ("pt_" ^ Sign.base_name a ^ "_inst");
+        val pt_inst = pt_inst_of thy2 a;
         val pt2' = pt_inst RS pt2;
         val pt2_ax = PureThy.get_thm thy2 (NameSpace.map_base (fn s => "pt_" ^ s ^ "2") a);
       in List.take (map standard (split_conj_thm
         (Goal.prove_global thy2 [] []
+           (augment_sort thy2 [pt_class_of thy2 a]
              (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
                 (map (fn ((s, T), x) =>
                     let val perm = Const (s, permT --> T --> T)
@@ -376,7 +377,7 @@
                        perm $ pi1 $ (perm $ pi2 $ Free (x, T)))
                     end)
                   (perm_names ~~
-                   map body_type perm_types ~~ perm_indnames))))
+                   map body_type perm_types ~~ perm_indnames)))))
            (fn _ => EVERY [indtac induction perm_indnames 1,
               ALLGOALS (asm_full_simp_tac (simpset_of thy2 addsimps [pt2', pt2_ax]))]))),
          length new_type_names)
@@ -394,14 +395,14 @@
         val permT = mk_permT (Type (a, []));
         val pi1 = Free ("pi1", permT);
         val pi2 = Free ("pi2", permT);
-        (*FIXME: not robust - better access these theorems using NominalData?*)
-        val at_inst = PureThy.get_thm thy2 ("at_" ^ Sign.base_name a ^ "_inst");
-        val pt_inst = PureThy.get_thm thy2 ("pt_" ^ Sign.base_name a ^ "_inst");
+        val at_inst = at_inst_of thy2 a;
+        val pt_inst = pt_inst_of thy2 a;
         val pt3' = pt_inst RS pt3;
         val pt3_rev' = at_inst RS (pt_inst RS pt3_rev);
         val pt3_ax = PureThy.get_thm thy2 (NameSpace.map_base (fn s => "pt_" ^ s ^ "3") a);
       in List.take (map standard (split_conj_thm
-        (Goal.prove_global thy2 [] [] (Logic.mk_implies
+        (Goal.prove_global thy2 [] []
+          (augment_sort thy2 [pt_class_of thy2 a] (Logic.mk_implies
              (HOLogic.mk_Trueprop (Const ("Nominal.prm_eq",
                 permT --> permT --> HOLogic.boolT) $ pi1 $ pi2),
               HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
@@ -412,7 +413,7 @@
                        perm $ pi2 $ Free (x, T))
                     end)
                   (perm_names ~~
-                   map body_type perm_types ~~ perm_indnames)))))
+                   map body_type perm_types ~~ perm_indnames))))))
            (fn _ => EVERY [indtac induction perm_indnames 1,
               ALLGOALS (asm_full_simp_tac (simpset_of thy2 addsimps [pt3', pt3_rev', pt3_ax]))]))),
          length new_type_names)
@@ -428,29 +429,29 @@
 
     fun composition_instance name1 name2 thy =
       let
-        val name1' = Sign.base_name name1;
-        val name2' = Sign.base_name name2;
-        val cp_class = Sign.intern_class thy ("cp_" ^ name1' ^ "_" ^ name2');
+        val cp_class = cp_class_of thy name1 name2;
+        val pt_class =
+          if name1 = name2 then [pt_class_of thy name1]
+          else [];
         val permT1 = mk_permT (Type (name1, []));
         val permT2 = mk_permT (Type (name2, []));
-        val augment = map_type_tfree
-          (fn (x, S) => TFree (x, cp_class :: S));
-        val Ts = map (augment o body_type) perm_types;
-        val cp_inst = PureThy.get_thm thy ("cp_" ^ name1' ^ "_" ^ name2' ^ "_inst");
+        val Ts = map body_type perm_types;
+        val cp_inst = cp_inst_of thy name1 name2;
         val simps = simpset_of thy addsimps (perm_fun_def ::
           (if name1 <> name2 then
-             let val dj = PureThy.get_thm thy ("dj_" ^ name2' ^ "_" ^ name1')
+             let val dj = dj_thm_of thy name2 name1
              in [dj RS (cp_inst RS dj_cp), dj RS dj_perm_perm_forget] end
            else
              let
-               val at_inst = PureThy.get_thm thy ("at_" ^ name1' ^ "_inst");
-               val pt_inst = PureThy.get_thm thy ("pt_" ^ name1' ^ "_inst");
+               val at_inst = at_inst_of thy name1;
+               val pt_inst = pt_inst_of thy name1;
              in
                [cp_inst RS cp1 RS sym,
                 at_inst RS (pt_inst RS pt_perm_compose) RS sym,
                 at_inst RS (pt_inst RS pt_perm_compose_rev) RS sym]
             end))
         val thms = split_conj_thm (Goal.prove_global thy [] []
+          (augment_sort thy (cp_class :: pt_class)
             (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
               (map (fn ((s, T), x) =>
                   let
@@ -463,25 +464,30 @@
                     (perm1 $ pi1 $ (perm2 $ pi2 $ Free (x, T)),
                      perm2 $ (perm3 $ pi1 $ pi2) $ (perm1 $ pi1 $ Free (x, T)))
                   end)
-                (perm_names ~~ Ts ~~ perm_indnames))))
+                (perm_names ~~ Ts ~~ perm_indnames)))))
           (fn _ => EVERY [indtac induction perm_indnames 1,
              ALLGOALS (asm_full_simp_tac simps)]))
       in
         foldl (fn ((s, tvs), thy) => AxClass.prove_arity
-            (s, replicate (length tvs) (cp_class :: classes), [cp_class])
+            (s, replicate (length tvs) (cp_class :: pt_class), [cp_class])
             (Class.intro_classes_tac [] THEN ALLGOALS (resolve_tac thms)) thy)
           thy (full_new_type_names' ~~ tyvars)
       end;
 
     val (perm_thmss,thy3) = thy2 |>
       fold (fn name1 => fold (composition_instance name1) atoms) atoms |>
-      curry (Library.foldr (fn ((i, (tyname, args, _)), thy) =>
-        AxClass.prove_arity (tyname, replicate (length args) classes, classes)
-        (Class.intro_classes_tac [] THEN REPEAT (EVERY
-           [resolve_tac perm_empty_thms 1,
-            resolve_tac perm_append_thms 1,
-            resolve_tac perm_eq_thms 1, assume_tac 1])) thy))
-        (List.take (descr, length new_type_names)) |>
+      fold (fn atom => fn thy =>
+        let val pt_name = pt_class_of thy atom
+        in
+          fold (fn (i, (tyname, args, _)) => AxClass.prove_arity
+              (tyname, replicate (length args) [pt_name], [pt_name])
+              (EVERY
+                [Class.intro_classes_tac [],
+                 resolve_tac perm_empty_thms 1,
+                 resolve_tac perm_append_thms 1,
+                 resolve_tac perm_eq_thms 1, assume_tac 1]))
+            (List.take (descr, length new_type_names)) thy
+        end) atoms |>
       PureThy.add_thmss
         [((space_implode "_" new_type_names ^ "_unfolded_perm_eq",
           unfolded_perm_eq_thms), [Simplifier.simp_add]),
@@ -513,9 +519,10 @@
           apfst (cons dt) (strip_option dt')
       | strip_option dt = ([], dt);
 
-    val dt_atomTs = distinct op = (map (typ_of_dtyp descr sorts')
+    val dt_atomTs = distinct op = (map (typ_of_dtyp descr sorts)
       (List.concat (map (fn (_, (_, _, cs)) => List.concat
         (map (List.concat o map (fst o strip_option) o snd) cs)) descr)));
+    val dt_atoms = map (fst o dest_Type) dt_atomTs;
 
     fun make_intr s T (cname, cargs) =
       let
@@ -523,9 +530,9 @@
           let
             val (dts, dt') = strip_option dt;
             val (dts', dt'') = strip_dtyp dt';
-            val Ts = map (typ_of_dtyp descr sorts') dts;
-            val Us = map (typ_of_dtyp descr sorts') dts';
-            val T = typ_of_dtyp descr sorts' dt'';
+            val Ts = map (typ_of_dtyp descr sorts) dts;
+            val Us = map (typ_of_dtyp descr sorts) dts';
+            val T = typ_of_dtyp descr sorts dt'';
             val free = mk_Free "x" (Us ---> T) j;
             val free' = app_bnds free (length Us);
             fun mk_abs_fun (T, (i, t)) =
@@ -580,23 +587,22 @@
 
     fun mk_perm_closed name = map (fn th => standard (th RS mp))
       (List.take (split_conj_thm (Goal.prove_global thy4 [] []
-        (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj (map
-           (fn ((s, T), x) =>
-              let
-                val T = map_type_tfree
-                  (fn (s, cs) => TFree (s, cs union cp_classes)) T;
-                val S = Const (s, T --> HOLogic.boolT);
-                val permT = mk_permT (Type (name, []))
-              in HOLogic.mk_imp (S $ Free (x, T),
-                S $ (Const ("Nominal.perm", permT --> T --> T) $
-                  Free ("pi", permT) $ Free (x, T)))
-              end) (rep_set_names'' ~~ recTs' ~~ perm_indnames'))))
-        (fn _ => EVERY (* CU: added perm_fun_def in the final tactic in order to deal with funs *)
+        (augment_sort thy4
+          (pt_class_of thy4 name :: map (cp_class_of thy4 name) (dt_atoms \ name))
+          (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj (map
+            (fn ((s, T), x) =>
+               let
+                 val S = Const (s, T --> HOLogic.boolT);
+                 val permT = mk_permT (Type (name, []))
+               in HOLogic.mk_imp (S $ Free (x, T),
+                 S $ (Const ("Nominal.perm", permT --> T --> T) $
+                   Free ("pi", permT) $ Free (x, T)))
+               end) (rep_set_names'' ~~ recTs' ~~ perm_indnames')))))
+        (fn _ => EVERY
            [indtac rep_induct [] 1,
             ALLGOALS (simp_tac (simpset_of thy4 addsimps
               (symmetric perm_fun_def :: abs_perm))),
-            ALLGOALS (resolve_tac rep_intrs
-               THEN_ALL_NEW (asm_full_simp_tac (simpset_of thy4 addsimps [perm_fun_def])))])),
+            ALLGOALS (resolve_tac rep_intrs THEN_ALL_NEW assume_tac)])),
         length new_type_names));
 
     val perm_closed_thmss = map mk_perm_closed atoms;
@@ -617,7 +623,7 @@
         let
           val permT = mk_permT (TFree (Name.variant tvs "'a", HOLogic.typeS));
           val pi = Free ("pi", permT);
-          val tvs' = map (fn s => TFree (s, the (AList.lookup op = sorts' s))) tvs;
+          val tvs' = map (fn s => TFree (s, the (AList.lookup op = sorts s))) tvs;
           val T = Type (Sign.intern_type thy name, tvs');
         in apfst (pair r o hd)
           (PureThy.add_defs_unchecked true [(("prm_" ^ name ^ "_def", Logic.mk_equals
@@ -641,19 +647,23 @@
 
     val _ = warning "prove that new types are in class pt_<name> ...";
 
-    fun pt_instance ((class, atom), perm_closed_thms) =
+    fun pt_instance (atom, perm_closed_thms) =
       fold (fn ((((((Abs_inverse, Rep_inverse), Rep),
         perm_def), name), tvs), perm_closed) => fn thy =>
-          AxClass.prove_arity
+          let
+            val pt_class = pt_class_of thy atom;
+            val cp_sort = map (cp_class_of thy atom) (dt_atoms \ atom)
+          in AxClass.prove_arity
             (Sign.intern_type thy name,
-              replicate (length tvs) (classes @ cp_classes), [class])
+              replicate (length tvs) (pt_class :: cp_sort), [pt_class])
             (EVERY [Class.intro_classes_tac [],
               rewrite_goals_tac [perm_def],
               asm_full_simp_tac (simpset_of thy addsimps [Rep_inverse]) 1,
               asm_full_simp_tac (simpset_of thy addsimps
                 [Rep RS perm_closed RS Abs_inverse]) 1,
               asm_full_simp_tac (HOL_basic_ss addsimps [PureThy.get_thm thy
-                ("pt_" ^ Sign.base_name atom ^ "3")]) 1]) thy)
+                ("pt_" ^ Sign.base_name atom ^ "3")]) 1]) thy
+          end)
         (Abs_inverse_thms ~~ Rep_inverse_thms ~~ Rep_thms ~~ perm_defs ~~
            new_type_names ~~ tyvars ~~ perm_closed_thms);
 
@@ -664,14 +674,17 @@
 
     fun cp_instance (atom1, perm_closed_thms1) (atom2, perm_closed_thms2) thy =
       let
-        val name = "cp_" ^ Sign.base_name atom1 ^ "_" ^ Sign.base_name atom2;
-        val class = Sign.intern_class thy name;
-        val cp1' = PureThy.get_thm thy (name ^ "_inst") RS cp1
+        val cp_class = cp_class_of thy atom1 atom2;
+        val sort =
+          pt_class_of thy atom1 :: map (cp_class_of thy atom1) (dt_atoms \ atom1) @
+          (if atom1 = atom2 then [cp_class_of thy atom1 atom1] else
+           pt_class_of thy atom2 :: map (cp_class_of thy atom2) (dt_atoms \ atom2));
+        val cp1' = cp_inst_of thy atom1 atom2 RS cp1
       in fold (fn ((((((Abs_inverse, Rep),
         perm_def), name), tvs), perm_closed1), perm_closed2) => fn thy =>
           AxClass.prove_arity
             (Sign.intern_type thy name,
-              replicate (length tvs) (classes @ cp_classes), [class])
+              replicate (length tvs) sort, [cp_class])
             (EVERY [Class.intro_classes_tac [],
               rewrite_goals_tac [perm_def],
               asm_full_simp_tac (simpset_of thy addsimps
@@ -687,8 +700,8 @@
 
     val thy7 = fold (fn x => fn thy => thy |>
       pt_instance x |>
-      fold (cp_instance (apfst snd x)) (atoms ~~ perm_closed_thmss))
-        (classes ~~ atoms ~~ perm_closed_thmss) thy6;
+      fold (cp_instance x) (atoms ~~ perm_closed_thmss))
+        (atoms ~~ perm_closed_thmss) thy6;
 
     (**** constructors ****)
 
@@ -741,14 +754,14 @@
       map (fn ((cname, cargs), idxs) => (cname, partition_cargs idxs cargs))
         (constrs ~~ idxss)))) (descr'' ~~ ndescr);
 
-    fun nth_dtyp' i = typ_of_dtyp descr'' sorts' (DtRec i);
+    fun nth_dtyp' i = typ_of_dtyp descr'' sorts (DtRec i);
 
     val rep_names = map (fn s =>
       Sign.intern_const thy7 ("Rep_" ^ s)) new_type_names;
     val abs_names = map (fn s =>
       Sign.intern_const thy7 ("Abs_" ^ s)) new_type_names;
 
-    val recTs = get_rec_types descr'' sorts';
+    val recTs = get_rec_types descr'' sorts;
     val newTs' = Library.take (length new_type_names, recTs');
     val newTs = Library.take (length new_type_names, recTs);
 
@@ -759,17 +772,17 @@
       let
         fun constr_arg ((dts, dt), (j, l_args, r_args)) =
           let
-            val xs = map (fn (dt, i) => mk_Free "x" (typ_of_dtyp descr'' sorts' dt) i)
+            val xs = map (fn (dt, i) => mk_Free "x" (typ_of_dtyp descr'' sorts dt) i)
               (dts ~~ (j upto j + length dts - 1))
-            val x = mk_Free "x" (typ_of_dtyp descr'' sorts' dt) (j + length dts)
+            val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts)
           in
             (j + length dts + 1,
              xs @ x :: l_args,
              foldr mk_abs_fun
                (case dt of
                   DtRec k => if k < length new_type_names then
-                      Const (List.nth (rep_names, k), typ_of_dtyp descr'' sorts' dt -->
-                        typ_of_dtyp descr sorts' dt) $ x
+                      Const (List.nth (rep_names, k), typ_of_dtyp descr'' sorts dt -->
+                        typ_of_dtyp descr sorts dt) $ x
                     else error "nested recursion not (yet) supported"
                 | _ => x) xs :: r_args)
           end
@@ -834,9 +847,12 @@
         val permT = mk_permT (Type (atom, []));
         val pi = Free ("pi", permT);
       in
-        Goal.prove_global thy8 [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq
-            (Const ("Nominal.perm", permT --> U --> U) $ pi $ (Rep $ x),
-             Rep $ (Const ("Nominal.perm", permT --> T --> T) $ pi $ x))))
+        Goal.prove_global thy8 [] []
+          (augment_sort thy8
+            (pt_class_of thy8 atom :: map (cp_class_of thy8 atom) (dt_atoms \ atom))
+            (HOLogic.mk_Trueprop (HOLogic.mk_eq
+              (Const ("Nominal.perm", permT --> U --> U) $ pi $ (Rep $ x),
+               Rep $ (Const ("Nominal.perm", permT --> T --> T) $ pi $ x)))))
           (fn _ => simp_tac (HOL_basic_ss addsimps (perm_defs @ Abs_inverse_thms @
             perm_closed_thms @ Rep_thms)) 1)
       end) Rep_thms;
@@ -846,7 +862,7 @@
 
     (* prove distinctness theorems *)
 
-    val distinct_props = DatatypeProp.make_distincts descr' sorts';
+    val distinct_props = DatatypeProp.make_distincts descr' sorts;
     val dist_rewrites = map2 (fn rep_thms => fn dist_lemma =>
       dist_lemma :: rep_thms @ [In0_eq, In1_eq, In0_not_In1, In1_not_In0])
         constr_rep_thmss dist_lemmas;
@@ -881,10 +897,10 @@
 
           fun constr_arg ((dts, dt), (j, l_args, r_args)) =
             let
-              val Ts = map (typ_of_dtyp descr'' sorts') dts;
+              val Ts = map (typ_of_dtyp descr'' sorts) dts;
               val xs = map (fn (T, i) => mk_Free "x" T i)
                 (Ts ~~ (j upto j + length dts - 1))
-              val x = mk_Free "x" (typ_of_dtyp descr'' sorts' dt) (j + length dts)
+              val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts)
             in
               (j + length dts + 1,
                xs @ x :: l_args,
@@ -895,8 +911,10 @@
           val c = Const (cname, map fastype_of l_args ---> T)
         in
           Goal.prove_global thy8 [] []
-            (HOLogic.mk_Trueprop (HOLogic.mk_eq
-              (perm (list_comb (c, l_args)), list_comb (c, r_args))))
+            (augment_sort thy8
+              (pt_class_of thy8 atom :: map (cp_class_of thy8 atom) (dt_atoms \ atom))
+              (HOLogic.mk_Trueprop (HOLogic.mk_eq
+                (perm (list_comb (c, l_args)), list_comb (c, r_args)))))
             (fn _ => EVERY
               [simp_tac (simpset_of thy8 addsimps (constr_rep_thm :: perm_defs)) 1,
                simp_tac (HOL_basic_ss addsimps (Rep_thms @ Abs_inverse_thms @
@@ -915,6 +933,10 @@
     val alpha = PureThy.get_thms thy8 "alpha";
     val abs_fresh = PureThy.get_thms thy8 "abs_fresh";
 
+    val pt_cp_sort =
+      map (pt_class_of thy8) dt_atoms @
+      maps (fn s => map (cp_class_of thy8 s) (dt_atoms \ s)) dt_atoms;
+
     val inject_thms = map (fn (((i, (_, _, constrs)), tname), constr_rep_thms) =>
       let val T = nth_dtyp' i
       in List.mapPartial (fn ((cname, dts), constr_rep_thm) =>
@@ -925,11 +947,11 @@
 
           fun make_inj ((dts, dt), (j, args1, args2, eqs)) =
             let
-              val Ts_idx = map (typ_of_dtyp descr'' sorts') dts ~~ (j upto j + length dts - 1);
+              val Ts_idx = map (typ_of_dtyp descr'' sorts) dts ~~ (j upto j + length dts - 1);
               val xs = map (fn (T, i) => mk_Free "x" T i) Ts_idx;
               val ys = map (fn (T, i) => mk_Free "y" T i) Ts_idx;
-              val x = mk_Free "x" (typ_of_dtyp descr'' sorts' dt) (j + length dts);
-              val y = mk_Free "y" (typ_of_dtyp descr'' sorts' dt) (j + length dts)
+              val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts);
+              val y = mk_Free "y" (typ_of_dtyp descr'' sorts dt) (j + length dts)
             in
               (j + length dts + 1,
                xs @ (x :: args1), ys @ (y :: args2),
@@ -941,17 +963,17 @@
           val Ts = map fastype_of args1;
           val c = Const (cname, Ts ---> T)
         in
-          Goal.prove_global thy8 [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq
-              (HOLogic.mk_eq (list_comb (c, args1), list_comb (c, args2)),
-               foldr1 HOLogic.mk_conj eqs)))
+          Goal.prove_global thy8 [] []
+            (augment_sort thy8 pt_cp_sort
+              (HOLogic.mk_Trueprop (HOLogic.mk_eq
+                (HOLogic.mk_eq (list_comb (c, args1), list_comb (c, args2)),
+                 foldr1 HOLogic.mk_conj eqs))))
             (fn _ => EVERY
                [asm_full_simp_tac (simpset_of thy8 addsimps (constr_rep_thm ::
                   rep_inject_thms')) 1,
                 TRY (asm_full_simp_tac (HOL_basic_ss addsimps (fresh_def :: supp_def ::
                   alpha @ abs_perm @ abs_fresh @ rep_inject_thms @
-                  perm_rep_perm_thms)) 1),
-                TRY (asm_full_simp_tac (HOL_basic_ss addsimps (perm_fun_def ::
-                  @{thm expand_fun_eq} :: rep_inject_thms @ perm_rep_perm_thms)) 1)])
+                  perm_rep_perm_thms)) 1)])
         end) (constrs ~~ constr_rep_thms)
       end) (List.take (pdescr, length new_type_names) ~~ new_type_names ~~ constr_rep_thmss);
 
@@ -968,9 +990,9 @@
 
           fun process_constr ((dts, dt), (j, args1, args2)) =
             let
-              val Ts_idx = map (typ_of_dtyp descr'' sorts') dts ~~ (j upto j + length dts - 1);
+              val Ts_idx = map (typ_of_dtyp descr'' sorts) dts ~~ (j upto j + length dts - 1);
               val xs = map (fn (T, i) => mk_Free "x" T i) Ts_idx;
-              val x = mk_Free "x" (typ_of_dtyp descr'' sorts' dt) (j + length dts)
+              val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts)
             in
               (j + length dts + 1,
                xs @ (x :: args1), foldr mk_abs_fun x xs :: args2)
@@ -983,10 +1005,11 @@
             Const ("Nominal.supp", fastype_of t --> HOLogic.mk_setT atomT) $ t;
           fun fresh t = fresh_const atomT (fastype_of t) $ Free ("a", atomT) $ t;
           val supp_thm = Goal.prove_global thy8 [] []
+            (augment_sort thy8 pt_cp_sort
               (HOLogic.mk_Trueprop (HOLogic.mk_eq
                 (supp c,
                  if null dts then Const ("{}", HOLogic.mk_setT atomT)
-                 else foldr1 (HOLogic.mk_binop "op Un") (map supp args2))))
+                 else foldr1 (HOLogic.mk_binop "op Un") (map supp args2)))))
             (fn _ =>
               simp_tac (HOL_basic_ss addsimps (supp_def ::
                  Un_assoc :: de_Morgan_conj :: Collect_disj_eq :: finite_Un ::
@@ -994,10 +1017,11 @@
                  abs_perm @ abs_fresh @ inject_thms' @ perm_thms')) 1)
         in
           (supp_thm,
-           Goal.prove_global thy8 [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq
-              (fresh c,
-               if null dts then HOLogic.true_const
-               else foldr1 HOLogic.mk_conj (map fresh args2))))
+           Goal.prove_global thy8 [] [] (augment_sort thy8 pt_cp_sort
+             (HOLogic.mk_Trueprop (HOLogic.mk_eq
+               (fresh c,
+                if null dts then HOLogic.true_const
+                else foldr1 HOLogic.mk_conj (map fresh args2)))))
              (fn _ =>
                simp_tac (HOL_ss addsimps [Un_iff, empty_iff, fresh_def, supp_thm]) 1))
         end) atoms) constrs)
@@ -1038,7 +1062,7 @@
 
     val Abs_inverse_thms' = map (fn r => r RS subst) Abs_inverse_thms;
 
-    val dt_induct_prop = DatatypeProp.make_ind descr' sorts';
+    val dt_induct_prop = DatatypeProp.make_ind descr' sorts;
     val dt_induct = Goal.prove_global thy8 []
       (Logic.strip_imp_prems dt_induct_prop) (Logic.strip_imp_concl dt_induct_prop)
       (fn {prems, ...} => EVERY
@@ -1064,11 +1088,13 @@
     val finite_supp_thms = map (fn atom =>
       let val atomT = Type (atom, [])
       in map standard (List.take
-        (split_conj_thm (Goal.prove_global thy8 [] [] (HOLogic.mk_Trueprop
-           (foldr1 HOLogic.mk_conj (map (fn (s, T) =>
-             Const ("Finite_Set.finite", HOLogic.mk_setT atomT --> HOLogic.boolT) $
-               (Const ("Nominal.supp", T --> HOLogic.mk_setT atomT) $ Free (s, T)))
-               (indnames ~~ recTs))))
+        (split_conj_thm (Goal.prove_global thy8 [] []
+           (augment_sort thy8 (fs_class_of thy8 atom :: pt_cp_sort)
+             (HOLogic.mk_Trueprop
+               (foldr1 HOLogic.mk_conj (map (fn (s, T) =>
+                 Const ("Finite_Set.finite", HOLogic.mk_setT atomT --> HOLogic.boolT) $
+                   (Const ("Nominal.supp", T --> HOLogic.mk_setT atomT) $ Free (s, T)))
+                   (indnames ~~ recTs)))))
            (fn _ => indtac dt_induct indnames 1 THEN
             ALLGOALS (asm_full_simp_tac (simpset_of thy8 addsimps
               (abs_supp @ supp_atm @
@@ -1096,10 +1122,10 @@
       DatatypeAux.store_thmss "supp" new_type_names supp_thms ||>>
       DatatypeAux.store_thmss_atts "fresh" new_type_names simp_atts fresh_thms ||>
       fold (fn (atom, ths) => fn thy =>
-        let val class = Sign.intern_class thy ("fs_" ^ Sign.base_name atom)
+        let val class = fs_class_of thy atom
         in fold (fn T => AxClass.prove_arity
             (fst (dest_Type T),
-              replicate (length sorts) [class], [class])
+              replicate (length sorts) (class :: pt_cp_sort), [class])
             (Class.intro_classes_tac [] THEN resolve_tac ths 1)) newTs thy
         end) (atoms ~~ finite_supp_thms);
 
@@ -1108,8 +1134,7 @@
     val pnames = if length descr'' = 1 then ["P"]
       else map (fn i => "P" ^ string_of_int i) (1 upto length descr'');
     val ind_sort = if null dt_atomTs then HOLogic.typeS
-      else Sign.certify_sort thy9 (map (fn T => Sign.intern_class thy9 ("fs_" ^
-        Sign.base_name (fst (dest_Type T)))) dt_atomTs);
+      else Sign.certify_sort thy9 (map (fs_class_of thy9) dt_atoms);
     val fsT = TFree ("'n", ind_sort);
     val fsT' = TFree ("'n", HOLogic.typeS);
 
@@ -1134,8 +1159,8 @@
     fun make_ind_prem fsT f k T ((cname, cargs), idxs) =
       let
         val recs = List.filter is_rec_type cargs;
-        val Ts = map (typ_of_dtyp descr'' sorts') cargs;
-        val recTs' = map (typ_of_dtyp descr'' sorts') recs;
+        val Ts = map (typ_of_dtyp descr'' sorts) cargs;
+        val recTs' = map (typ_of_dtyp descr'' sorts) recs;
         val tnames = Name.variant_list pnames (DatatypeProp.make_tnames Ts);
         val rec_tnames = map fst (List.filter (is_rec_type o snd) (tnames ~~ cargs));
         val frees = tnames ~~ Ts;
@@ -1199,12 +1224,10 @@
             (Free (tname, T))))
         (descr'' ~~ recTs ~~ tnames)));
 
-    val fin_set_supp = map (fn Type (s, _) =>
-      PureThy.get_thm thy9 ("at_" ^ Sign.base_name s ^ "_inst") RS
-        at_fin_set_supp) dt_atomTs;
-    val fin_set_fresh = map (fn Type (s, _) =>
-      PureThy.get_thm thy9 ("at_" ^ Sign.base_name s ^ "_inst") RS
-        at_fin_set_fresh) dt_atomTs;
+    val fin_set_supp = map (fn s =>
+      at_inst_of thy9 s RS at_fin_set_supp) dt_atoms;
+    val fin_set_fresh = map (fn s =>
+      at_inst_of thy9 s RS at_fin_set_fresh) dt_atoms;
     val pt1_atoms = map (fn Type (s, _) =>
       PureThy.get_thm thy9 ("pt_" ^ Sign.base_name s ^ "1")) dt_atomTs;
     val pt2_atoms = map (fn Type (s, _) =>
@@ -1248,6 +1271,10 @@
           [SOME (cterm_of thy a), NONE, SOME (cterm_of thy b)] th
       end;
 
+    val fs_cp_sort =
+      map (fs_class_of thy9) dt_atoms @
+      maps (fn s => map (cp_class_of thy9 s) (dt_atoms \ s)) dt_atoms;
+
     (**********************************************************************
       The subgoals occurring in the proof of induct_aux have the
       following parameters:
@@ -1263,7 +1290,9 @@
 
     val _ = warning "proving strong induction theorem ...";
 
-    val induct_aux = Goal.prove_global thy9 [] ind_prems' ind_concl' (fn {prems, context} =>
+    val induct_aux = Goal.prove_global thy9 []
+        (map (augment_sort thy9 fs_cp_sort) ind_prems')
+        (augment_sort thy9 fs_cp_sort ind_concl') (fn {prems, context} =>
       let
         val (prems1, prems2) = chop (length dt_atomTs) prems;
         val ind_ss2 = HOL_ss addsimps
@@ -1276,7 +1305,8 @@
           fin_set_fresh @ calc_atm;
         val ind_ss5 = HOL_basic_ss addsimps pt1_atoms;
         val ind_ss6 = HOL_basic_ss addsimps flat perm_simps';
-        val th = Goal.prove context [] [] aux_ind_concl
+        val th = Goal.prove context [] []
+          (augment_sort thy9 fs_cp_sort aux_ind_concl)
           (fn {context = context1, ...} =>
              EVERY (indtac dt_induct tnames 1 ::
                maps (fn ((_, (_, _, constrs)), (_, constrs')) =>
@@ -1351,17 +1381,19 @@
       end);
 
     val induct_aux' = Thm.instantiate ([],
-      map (fn (s, T) =>
-        let val pT = TVar (("'n", 0), HOLogic.typeS) --> T --> HOLogic.boolT
-        in (cterm_of thy9 (Var ((s, 0), pT)), cterm_of thy9 (Free (s, pT))) end)
-          (pnames ~~ recTs) @
+      map (fn (s, v as Var (_, T)) =>
+        (cterm_of thy9 v, cterm_of thy9 (Free (s, T))))
+          (pnames ~~ map head_of (HOLogic.dest_conj
+             (HOLogic.dest_Trueprop (concl_of induct_aux)))) @
       map (fn (_, f) =>
         let val f' = Logic.varify f
         in (cterm_of thy9 f',
           cterm_of thy9 (Const ("Nominal.supp", fastype_of f')))
         end) fresh_fs) induct_aux;
 
-    val induct = Goal.prove_global thy9 [] ind_prems ind_concl
+    val induct = Goal.prove_global thy9 []
+      (map (augment_sort thy9 fs_cp_sort) ind_prems)
+      (augment_sort thy9 fs_cp_sort ind_concl)
       (fn {prems, ...} => EVERY
          [rtac induct_aux' 1,
           REPEAT (resolve_tac fs_atoms 1),
@@ -1380,16 +1412,10 @@
 
     val used = foldr add_typ_tfree_names [] recTs;
 
-    val (rec_result_Ts', rec_fn_Ts') = DatatypeProp.make_primrec_Ts descr' sorts' used;
+    val (rec_result_Ts', rec_fn_Ts') = DatatypeProp.make_primrec_Ts descr' sorts used;
 
     val rec_sort = if null dt_atomTs then HOLogic.typeS else
-      let val names = map (Sign.base_name o fst o dest_Type) dt_atomTs
-      in Sign.certify_sort thy10 (map (Sign.intern_class thy10)
-        (map (fn s => "pt_" ^ s) names @
-         List.concat (map (fn s => List.mapPartial (fn s' =>
-           if s = s' then NONE
-           else SOME ("cp_" ^ s ^ "_" ^ s')) names) names)))
-      end;
+      Sign.certify_sort thy10 pt_cp_sort;
 
     val rec_result_Ts = map (fn TFree (s, _) => TFree (s, rec_sort)) rec_result_Ts';
     val rec_fn_Ts = map (typ_subst_atomic (rec_result_Ts' ~~ rec_result_Ts)) rec_fn_Ts';
@@ -1429,7 +1455,7 @@
     fun make_rec_intr T p rec_set ((rec_intr_ts, rec_prems, rec_prems',
           rec_eq_prems, l), ((cname, cargs), idxs)) =
       let
-        val Ts = map (typ_of_dtyp descr'' sorts') cargs;
+        val Ts = map (typ_of_dtyp descr'' sorts) cargs;
         val frees = map (fn i => "x" ^ string_of_int i) (1 upto length Ts) ~~ Ts;
         val frees' = partition_cargs idxs frees;
         val binders = List.concat (map fst frees');
@@ -1508,7 +1534,8 @@
           end) (recTs ~~ rec_result_Ts ~~ rec_sets ~~ rec_sets_pi ~~ (1 upto length recTs));
         val ths = map (fn th => standard (th RS mp)) (split_conj_thm
           (Goal.prove_global thy11 [] []
-            (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj (map HOLogic.mk_imp ps)))
+            (augment_sort thy1 pt_cp_sort
+              (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj (map HOLogic.mk_imp ps))))
             (fn _ => rtac rec_induct 1 THEN REPEAT
                (simp_tac (Simplifier.theory_context thy11 HOL_basic_ss
                   addsimps flat perm_simps'
@@ -1517,7 +1544,8 @@
                  asm_simp_tac (HOL_ss addsimps (fresh_bij @ perm_bij))) 1))))
         val ths' = map (fn ((P, Q), th) =>
           Goal.prove_global thy11 [] []
-            (Logic.mk_implies (HOLogic.mk_Trueprop Q, HOLogic.mk_Trueprop P))
+            (augment_sort thy1 pt_cp_sort
+              (Logic.mk_implies (HOLogic.mk_Trueprop Q, HOLogic.mk_Trueprop P)))
             (fn _ => dtac (Thm.instantiate ([],
                  [(cterm_of thy11 (Var (("pi", 0), permT)),
                    cterm_of thy11 (Const ("List.rev", permT --> permT) $ pi))]) th) 1 THEN
@@ -1537,16 +1565,19 @@
             (rec_fns ~~ rec_fn_Ts)
       in
         map (fn th => standard (th RS mp)) (split_conj_thm
-          (Goal.prove_global thy11 [] fins
-            (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
-              (map (fn (((T, U), R), i) =>
-                 let
-                   val x = Free ("x" ^ string_of_int i, T);
-                   val y = Free ("y" ^ string_of_int i, U)
-                 in
-                   HOLogic.mk_imp (R $ x $ y,
-                     finite $ (Const ("Nominal.supp", U --> aset) $ y))
-                 end) (recTs ~~ rec_result_Ts ~~ rec_sets ~~ (1 upto length recTs)))))
+          (Goal.prove_global thy11 []
+            (map (augment_sort thy11 fs_cp_sort) fins)
+            (augment_sort thy11 fs_cp_sort
+              (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
+                (map (fn (((T, U), R), i) =>
+                   let
+                     val x = Free ("x" ^ string_of_int i, T);
+                     val y = Free ("y" ^ string_of_int i, U)
+                   in
+                     HOLogic.mk_imp (R $ x $ y,
+                       finite $ (Const ("Nominal.supp", U --> aset) $ y))
+                   end) (recTs ~~ rec_result_Ts ~~ rec_sets ~~
+                     (1 upto length recTs))))))
             (fn {prems = fins, ...} =>
               (rtac rec_induct THEN_ALL_NEW cut_facts_tac fins) 1 THEN REPEAT
                (NominalPermeq.finite_guess_tac (HOL_ss addsimps [fs_name]) 1))))
@@ -1560,6 +1591,8 @@
            (Const ("Nominal.supp", T --> HOLogic.mk_setT aT) $ f)))
            (rec_fns ~~ rec_fn_Ts)) dt_atomTs;
 
+    val rec_fns' = map (augment_sort thy11 fs_cp_sort) rec_fns;
+
     val rec_fresh_thms = map (fn ((aT, eqvt_ths), finite_prems) =>
       let
         val name = Sign.base_name (fst (dest_Type aT));
@@ -1570,16 +1603,18 @@
       in
         map (fn (((T, U), R), eqvt_th) =>
           let
-            val x = Free ("x", T);
+            val x = Free ("x", augment_sort_typ thy11 fs_cp_sort T);
             val y = Free ("y", U);
             val y' = Free ("y'", U)
           in
-            standard (Goal.prove (ProofContext.init thy11) [] (finite_prems @
-                [HOLogic.mk_Trueprop (R $ x $ y),
-                 HOLogic.mk_Trueprop (HOLogic.mk_all ("y'", U,
-                   HOLogic.mk_imp (R $ x $ y', HOLogic.mk_eq (y', y)))),
-                 HOLogic.mk_Trueprop (fresh_const aT T $ a $ x)] @
-              freshs)
+            standard (Goal.prove (ProofContext.init thy11) []
+              (map (augment_sort thy11 fs_cp_sort)
+                (finite_prems @
+                   [HOLogic.mk_Trueprop (R $ x $ y),
+                    HOLogic.mk_Trueprop (HOLogic.mk_all ("y'", U,
+                      HOLogic.mk_imp (R $ x $ y', HOLogic.mk_eq (y', y)))),
+                    HOLogic.mk_Trueprop (fresh_const aT T $ a $ x)] @
+                 freshs))
               (HOLogic.mk_Trueprop (fresh_const aT U $ a $ y))
               (fn {prems, context} =>
                  let
@@ -1588,7 +1623,7 @@
                    val unique_prem' = unique_prem RS spec RS mp;
                    val unique = [unique_prem', unique_prem' RS sym] MRS trans;
                    val _ $ (_ $ (_ $ S $ _)) $ _ = prop_of supports_fresh;
-                   val tuple = foldr1 HOLogic.mk_prod (x :: rec_fns)
+                   val tuple = foldr1 HOLogic.mk_prod (x :: rec_fns')
                  in EVERY
                    [rtac (Drule.cterm_instantiate
                       [(cterm_of thy11 S,
@@ -1631,7 +1666,7 @@
       (rec_unique_frees ~~ rec_result_Ts ~~ rec_sets);
 
     val induct_aux_rec = Drule.cterm_instantiate
-      (map (pairself (cterm_of thy11))
+      (map (pairself (cterm_of thy11) o apsnd (augment_sort thy11 fs_cp_sort))
          (map (fn (aT, f) => (Logic.varify f, Abs ("z", HOLogic.unitT,
             Const ("Nominal.supp", fun_tupleT --> HOLogic.mk_setT aT) $ fun_tuple)))
               fresh_fs @
@@ -1668,8 +1703,10 @@
 
     val rec_unique_thms = split_conj_thm (Goal.prove
       (ProofContext.init thy11) (map fst rec_unique_frees)
-      (List.concat finite_premss @ finite_ctxt_prems @ rec_prems @ rec_prems')
-      (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj rec_unique_concls))
+      (map (augment_sort thy11 fs_cp_sort)
+        (List.concat finite_premss @ finite_ctxt_prems @ rec_prems @ rec_prems'))
+      (augment_sort thy11 fs_cp_sort
+        (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj rec_unique_concls)))
       (fn {prems, context} =>
          let
            val k = length rec_fns;
@@ -1680,10 +1717,12 @@
            val P_ths = map (fn th => th RS mp) (split_conj_thm
              (Goal.prove context
                (map fst (rec_unique_frees'' @ rec_unique_frees')) []
-               (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
-                  (map (fn (((x, y), S), P) => HOLogic.mk_imp
-                    (S $ Free x $ Free y, P $ (Free y)))
-                      (rec_unique_frees'' ~~ rec_unique_frees' ~~ rec_sets ~~ rec_preds))))
+               (augment_sort thy11 fs_cp_sort
+                 (HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj
+                    (map (fn (((x, y), S), P) => HOLogic.mk_imp
+                      (S $ Free x $ Free y, P $ (Free y)))
+                        (rec_unique_frees'' ~~ rec_unique_frees' ~~
+                           rec_sets ~~ rec_preds)))))
                (fn _ =>
                   rtac rec_induct 1 THEN
                   REPEAT ((resolve_tac P_ind_ths THEN_ALL_NEW assume_tac) 1))));
@@ -1722,7 +1761,10 @@
                       chop (length params div 2) (map term_of params);
                     val params' = params1 @ params2;
                     val rec_prems = filter (fn th => case prop_of th of
-                      _ $ (S $ _ $ _) => S mem rec_sets | _ => false) prems';
+                        _ $ p => (case head_of p of
+                          Const (s, _) => s mem rec_set_names
+                        | _ => false)
+                      | _ => false) prems';
                     val fresh_prems = filter (fn th => case prop_of th of
                         _ $ (Const ("Nominal.fresh", _) $ _ $ _) => true
                       | _ $ (Const ("Not", _) $ _) => true
@@ -1731,7 +1773,7 @@
 
                     val _ = warning "step 1: obtaining fresh names";
                     val (freshs1, freshs2, context'') = fold
-                      (obtain_fresh_name (rec_ctxt :: rec_fns @ params')
+                      (obtain_fresh_name (rec_ctxt :: rec_fns' @ params')
                          (List.concat (map snd finite_thss) @
                             finite_ctxt_ths @ rec_prems)
                          rec_fin_supp_thms')
@@ -1802,7 +1844,8 @@
                     val rec_prems' = map (fn th =>
                       let
                         val _ $ (S $ x $ y) = prop_of th;
-                        val k = find_index (equal S) rec_sets;
+                        val Const (s, _) = head_of S;
+                        val k = find_index (equal s) rec_set_names;
                         val pi = rpi1 @ pi2;
                         fun mk_pi z = fold_rev (mk_perm []) pi z;
                         fun eqvt_tac p =
@@ -1988,7 +2031,9 @@
         fun solve rules prems = resolve_tac rules THEN_ALL_NEW
           (resolve_tac prems THEN_ALL_NEW atac)
       in
-        Goal.prove_global thy12 [] prems' concl'
+        Goal.prove_global thy12 []
+          (map (augment_sort thy12 fs_cp_sort) prems')
+          (augment_sort thy12 fs_cp_sort concl')
           (fn {prems, ...} => EVERY
             [rewrite_goals_tac reccomb_defs,
              rtac the1_equality 1,
@@ -1996,7 +2041,7 @@
              resolve_tac rec_intrs 1,
              REPEAT (solve (prems @ rec_total_thms) prems 1)])
       end) (rec_eq_prems ~~
-        DatatypeProp.make_primrecs new_type_names descr' sorts' thy12);
+        DatatypeProp.make_primrecs new_type_names descr' sorts thy12);
 
     val dt_infos = map (make_dt_info pdescr sorts induct reccomb_names rec_thms)
       ((0 upto length descr1 - 1) ~~ descr1 ~~ distinct_thms ~~ inject_thms);