src/HOL/Nominal/nominal_inductive.ML
changeset 22730 8bcc8809ed3b
parent 22544 549615dcd4f2
child 22755 e268f608669a
--- a/src/HOL/Nominal/nominal_inductive.ML	Thu Apr 19 16:27:53 2007 +0200
+++ b/src/HOL/Nominal/nominal_inductive.ML	Thu Apr 19 16:38:59 2007 +0200
@@ -8,8 +8,8 @@
 
 signature NOMINAL_INDUCTIVE =
 sig
-  val nominal_inductive: string -> (string * string list) list -> theory -> Proof.state
-  val equivariance: string -> theory -> theory
+  val prove_strong_ind: string -> (string * string list) list -> theory -> Proof.state
+  val prove_eqvt: string -> string list -> theory -> theory
 end
 
 structure NominalInductive : NOMINAL_INDUCTIVE =
@@ -49,9 +49,11 @@
   | add_binders thy i (Abs (_, _, t)) bs = add_binders thy (i + 1) t bs
   | add_binders thy i _ bs = bs;
 
-fun prove_strong_ind raw_induct names avoids thy =
+fun prove_strong_ind s avoids thy =
   let
     val ctxt = ProofContext.init thy;
+    val ({names, ...}, {raw_induct, ...}) =
+      InductivePackage.the_inductive ctxt (Sign.intern_const thy s);
     val induct_cases = map fst (fst (RuleCases.get (the
       (InductAttrib.lookup_inductS ctxt (hd names)))));
     val raw_induct' = Logic.unvarify (prop_of raw_induct);
@@ -314,10 +316,24 @@
       (map (map (rpair [])) vc_compat)
   end;
 
-fun prove_eqvt names raw_induct intrs thy =
+fun prove_eqvt s xatoms thy =
   let
     val ctxt = ProofContext.init thy;
-    val atoms = NominalAtoms.atoms_of thy;
+    val ({names, ...}, {raw_induct, intrs, ...}) =
+      InductivePackage.the_inductive ctxt (Sign.intern_const thy s);
+    val atoms' = NominalAtoms.atoms_of thy;
+    val atoms =
+      if null xatoms then atoms' else
+      let val atoms = map (Sign.intern_type thy) xatoms
+      in
+        (case duplicates op = atoms of
+             [] => ()
+           | xs => error ("Duplicate atoms: " ^ commas xs);
+         case atoms \\ atoms' of
+             [] => ()
+           | xs => error ("No such atoms: " ^ commas xs);
+         atoms)
+      end;
     val eqvt_ss = HOL_basic_ss addsimps NominalThmDecls.get_eqvt_thms thy;
     val t = Logic.unvarify (concl_of raw_induct);
     val pi = Name.variant (add_term_names (t, [])) "pi";
@@ -363,20 +379,6 @@
       Theory.parent_path) (names ~~ transp thss) thy
   end;
 
-fun gen_nominal_inductive f s avoids thy =
-  let
-    val ctxt = ProofContext.init thy;
-    val ({names, ...}, {raw_induct, intrs, ...}) =
-      InductivePackage.the_inductive ctxt (Sign.intern_const thy s);
-  in
-    thy |>
-    prove_eqvt names raw_induct intrs |>
-    f raw_induct names avoids
-  end;
-
-val nominal_inductive = gen_nominal_inductive prove_strong_ind;
-fun equivariance s = gen_nominal_inductive (K (K (K I))) s [];
-
 
 (* outer syntax *)
 
@@ -387,12 +389,13 @@
     "prove equivariance and strong induction theorem for inductive predicate involving nominal datatypes" K.thy_goal
     (P.name -- Scan.optional (P.$$$ "avoids" |-- P.and_list1 (P.name --
       (P.$$$ ":" |-- Scan.repeat1 P.name))) [] >> (fn (name, avoids) =>
-        Toplevel.print o Toplevel.theory_to_proof (nominal_inductive name avoids)));
+        Toplevel.print o Toplevel.theory_to_proof (prove_strong_ind name avoids)));
 
 val equivarianceP =
   OuterSyntax.command "equivariance"
     "prove equivariance for inductive predicate involving nominal datatypes" K.thy_decl
-    (P.name >> (Toplevel.theory o equivariance));
+    (P.name -- Scan.optional (P.$$$ "[" |-- P.list1 P.name --| P.$$$ "]") [] >>
+      (fn (name, atoms) => Toplevel.theory (prove_eqvt name atoms)));
 
 val _ = OuterSyntax.add_keywords ["avoids"];
 val _ = OuterSyntax.add_parsers [nominal_inductiveP, equivarianceP];