code cleanup, instance_subsort now working
authorhaftmann
Mon, 14 Aug 2006 13:46:20 +0200
changeset 20385 2f52b5aba086
parent 20384 049d955cf716
child 20386 d1cbe5aa6bf2
code cleanup, instance_subsort now working
src/Pure/Tools/class_package.ML
--- a/src/Pure/Tools/class_package.ML	Mon Aug 14 13:46:19 2006 +0200
+++ b/src/Pure/Tools/class_package.ML	Mon Aug 14 13:46:20 2006 +0200
@@ -25,12 +25,9 @@
   val instance_sort_i: class * sort -> theory -> Proof.state
   val prove_instance_sort: tactic -> class * sort -> theory -> theory
 
-  val intern_class: theory -> xstring -> class
-  val intern_sort: theory -> sort -> sort
-  val extern_class: theory -> class -> xstring
-  val extern_sort: theory -> sort -> sort
   val certify_class: theory -> class -> class
   val certify_sort: theory -> sort -> sort
+  val read_class: theory -> xstring -> class
   val read_sort: theory -> string -> sort
   val operational_sort_of: theory -> sort -> sort
   val operational_algebra: theory -> Sorts.algebra
@@ -49,33 +46,13 @@
   datatype classlookup = Instance of (class * string) * classlookup list list
                        | Lookup of class list * (string * (int * int))
   val sortcontext_of_typ: theory -> typ -> sortcontext
-  val sortlookup: theory -> sort * typ -> classlookup list
-  val sortlookups_const: theory -> string * typ -> classlookup list list
+  val sortlookup: theory -> typ * sort -> classlookup list
 end;
 
 structure ClassPackage : CLASS_PACKAGE =
 struct
 
 
-(* auxiliary *)
-
-fun instantiations_of thy (ty, ty') =
-  let
-(*     val _ = writeln "A";  *)
-    val vartab = typ_tvars ty;
-(*     val _ = writeln "B";  *)
-    fun prep_vartab (v, (_, ty)) =
-      case (the o AList.lookup (op =) vartab) v
-       of [] => NONE
-        | sort => SOME ((v, sort), ty);
-(*     val _ = writeln "C";  *)
-  in case try (Sign.typ_match thy (ty, ty')) Vartab.empty
-   of NONE => ((*writeln "D";*)NONE)
-    | SOME vartab =>
-        ((*writeln "E";*)SOME ((map_filter prep_vartab o Vartab.dest) vartab))
-  end;
-
-
 (* theory data *)
 
 datatype class_data = ClassData of {
@@ -83,7 +60,7 @@
   name_axclass: string,
   var: string,
   consts: (string * (string * typ)) list,
-    (*locale parameter ~> toplevel constant*)
+    (*locale parameter ~> toplevel theory constant*)
   propnames: string list
 } * thm list Symtab.table;
 
@@ -161,7 +138,7 @@
   let
     fun ancestry class anc =
       anc
-      |> cons class
+      |> insert (op =) class
       |> fold ancestry (the_superclasses thy class);
   in fold ancestry classes [] end;
 
@@ -216,7 +193,7 @@
   );
 
 
-(* name handling *)
+(* certification and reading *)
 
 fun certify_class thy class =
   (fn class => (the_class_data thy class; class)) (Sign.certify_class thy class);
@@ -224,18 +201,9 @@
 fun certify_sort thy sort =
   map (fn class => (the_class_data thy class; class)) (Sign.certify_sort thy sort);
 
-fun intern_class thy =
+fun read_class thy =
   certify_class thy o Sign.intern_class thy;
 
-fun intern_sort thy =
-  certify_sort thy o Sign.intern_sort thy;
-
-fun extern_class thy =
-  Sign.extern_class thy o certify_class thy;
-
-fun extern_sort thy =
-  Sign.extern_sort thy o certify_sort thy;
-
 fun read_sort thy =
   certify_sort thy o Sign.read_sort thy;
 
@@ -315,7 +283,7 @@
     val {intro, axioms, ...} = AxClass.get_definition thy' c;
   in ((c, (intro, axioms)), thy') end;
 
-(* FIXME proper locale interface *)
+(*FIXME proper locale interface*)
 fun prove_interpretation_i (prfx, atts) expr insts tac thy =
   let
     fun ad_hoc_term (Const (c, ty)) =
@@ -352,10 +320,10 @@
         fun extract_classvar ((c, ty), _) w =
           (case add_typ_tfrees (ty, [])
            of [(v, _)] => (case w
-               of SOME u => if u = v then w else error ("additonal type variable in type signature of constant " ^ quote c)
+               of SOME u => if u = v then w else error ("Additonal type variable in type signature of constant " ^ quote c)
                 | NONE => SOME v)
-            | [] => error ("no type variable in type signature of constant " ^ quote c)
-            | _ => error ("more than one type variable in type signature of constant " ^ quote c));
+            | [] => error ("No type variable in type signature of constant " ^ quote c)
+            | _ => error ("More than one type variable in type signature of constant " ^ quote c));
         val consts1 =
           Locale.parameters_of thy name_locale
           |> map (apsnd Syntax.unlocalize_mixfix)
@@ -411,7 +379,7 @@
 
 in
 
-val class = gen_class (Locale.add_locale false) intern_class;
+val class = gen_class (Locale.add_locale false) read_class;
 val class_i = gen_class (Locale.add_locale_i false) certify_class;
 
 end; (* local *)
@@ -424,14 +392,9 @@
     val ((c, ty), _) = Sign.cert_def (Sign.pp thy) t;
     val atts = map (prep_att thy) raw_atts;
     val insts = (Consts.typargs (Sign.consts_of thy) (c, ty))
-    fun flat_typ (Type (tyco, tys)) = tyco :: maps flat_typ tys
-      | flat_typ _ = [];
     val name = case raw_name
-     of "" => let
-            val tycos = maps flat_typ insts;
-            val names = map NameSpace.base (c :: tycos);
-          in Thm.def_name (space_implode "_" names) end
-      | _ => raw_name;
+     of "" => NONE
+      | _ => SOME raw_name;
   in (c, (insts, ((name, t), atts))) end;
 
 fun read_def thy = gen_read_def thy Attrib.attribute read_axm;
@@ -472,22 +435,26 @@
         val ty = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) (Name.names Name.context "'a" asorts))
       in maps (get_consts_class tyco ty) (the_ancestry theory sort) end;
     val cs = maps get_consts_sort arities;
+    fun mk_typnorm thy (ty, ty_sc) =
+      case try (Sign.typ_match thy (Logic.varifyT ty_sc, ty)) Vartab.empty
+       of SOME env => SOME (Logic.varifyT #> Envir.typ_subst_TVars env #> Logic.unvarifyT)
+        | NONE => NONE;
     fun read_defs defs cs thy_read =
       let
         fun read raw_def cs =
           let
-            val (c, (inst, ((_, t), atts))) = read_def thy_read raw_def;
-            val ty = Logic.varifyT (Consts.instance (Sign.consts_of thy_read) (c, inst));
+            val (c, (inst, ((name_opt, t), atts))) = read_def thy_read raw_def;
+            val ty = Consts.instance (Sign.consts_of thy_read) (c, inst);
             val ((tyco, class), ty') = case AList.lookup (op =) cs c
              of NONE => error ("superfluous definition for constant " ^ quote c)
               | SOME class_ty => class_ty;
-            val name = Thm.def_name (NameSpace.base c ^ "_" ^ NameSpace.base tyco);
-            val t' = case instantiations_of thy_read (ty, ty')
+            val name = case name_opt
+             of NONE => Thm.def_name (NameSpace.base c ^ "_" ^ NameSpace.base tyco)
+              | SOME name => name;
+            val t' = case mk_typnorm thy_read (ty', ty)
              of NONE => error ("superfluous definition for constant " ^
                   quote c ^ "::" ^ Sign.string_of_typ thy_read ty)
-              | SOME insttab =>
-                  map_term_types
-                    (Logic.unvarifyT o Term.instantiateT insttab o Logic.varifyT) t
+              | SOME norm => map_term_types norm t
           in (((class, tyco), ((name, t'), atts)), AList.delete (op =) c cs) end;
       in fold_map read defs cs end;
     val (defs, _) = assume_arities_thy theory arities_pair (read_defs raw_defs cs);
@@ -497,7 +464,7 @@
       in
         thy
         |> Sign.add_const_constraint_i (c, NONE)
-        |> pair (c, Logic.legacy_unvarifyT ty)
+        |> pair (c, Logic.unvarifyT ty)
       end;
     fun add_defs defs thy =
       thy
@@ -519,7 +486,7 @@
       |> fold Sign.add_const_constraint_i (map (apsnd SOME) cs);
   in
     theory
-    |> fold_map get_remove_contraint (map fst cs)
+    |> fold_map get_remove_contraint (map fst cs |> distinct (op =))
     ||>> add_defs defs
     |-> (fn (cs, def_thms) =>
        fold add_inst_def def_thms
@@ -545,13 +512,13 @@
 
 local
 
-(* FIXME proper locale interface *)
 fun prove_interpretation_in tac after_qed (name, expr) thy =
   thy
   |> Locale.interpretation_in_locale (ProofContext.theory after_qed) (name, expr)
   |> Proof.global_terminal_proof (Method.Basic (fn _ => Method.SIMPLE_METHOD tac), NONE)
   |> ProofContext.theory_of;
 
+(*FIXME very ad-hoc, needs proper locale interface*)
 fun gen_instance_sort prep_class prep_sort do_proof (raw_class, raw_sort) theory =
   let
     val class = prep_class theory raw_class;
@@ -582,7 +549,7 @@
     |> do_proof after_qed (loc_name, loc_expr)
   end;
 
-fun instance_sort' do_proof = gen_instance_sort intern_class read_sort do_proof;
+fun instance_sort' do_proof = gen_instance_sort read_class read_sort do_proof;
 fun instance_sort_i' do_proof = gen_instance_sort certify_class certify_sort do_proof;
 val setup_proof = Locale.interpretation_in_locale o ProofContext.theory;
 val tactic_proof = prove_interpretation_in;
@@ -601,15 +568,13 @@
 type sortcontext = (string * sort) list;
 
 fun sortcontext_of_typ thy ty = fold_atyps
-  (fn TFree (a, S) =>
+  (fn TFree (v, S) =>
     (case operational_sort_of thy S of
       [] => I
-    | S' => insert (op =) (a, S'))
-  | T => raise TYPE ("Illegal schematic type variable", [T], [])) ty []
-  |> sort_wrt fst;   (* FIXME really required?!? *)
+    | S' => insert (op =) (v, S'))) (Type.no_tvars ty) [];
 
 datatype classlookup = Instance of (class * string) * classlookup list list
-                     | Lookup of class list * (string * (int * int))
+                    | Lookup of class list * (string * (int * int));
 
 fun pretty_lookup' (Instance ((class, tyco), lss)) =
       (Pretty.block o Pretty.breaks) (
@@ -621,7 +586,7 @@
         [Pretty.str (v ^ "!" ^ string_of_int i ^ "/" ^ string_of_int j)])
 and pretty_lookup ls = (Pretty.enum "," "(" ")" o map pretty_lookup') ls;
 
-fun sortlookup thy (sort_decl, typ_ctxt) =
+fun sortlookup thy (typ_ctxt, sort_decl) =
   let
     val pp = Sign.pp thy;
     val algebra = Sorts.project_algebra pp (is_operational_class thy)
@@ -631,38 +596,14 @@
       | classrel (Instance ((_, tyco), lss), _) class =
           Instance ((class, tyco), lss);
     fun constructor tyco lss class =
-      Instance ((class, tyco), (map o map) fst lss)
+      Instance ((class, tyco), (map o map) fst lss);
     fun variable (TFree (v, sort)) =
           map_index (fn (n, class) => (Lookup ([], (v, (n, length sort))), class))
-            (operational_sort_of thy sort)
-      | variable (TVar _) = error "TVar encountered while deriving sortlookup";
+            (operational_sort_of thy sort);
   in
     Sorts.of_sort_derivation pp algebra
       {classrel = classrel, constructor = constructor, variable = variable}
-      (typ_ctxt, operational_sort_of thy sort_decl)
-  end;
-
-fun sortlookups_const thy (c, typ_ctxt) =
-  let
-(*     val _ = writeln c  *)
-    val typ_decl = case AxClass.class_of_param thy c
-     of NONE => Sign.the_const_type thy c
-      | SOME class => case the_consts_sign thy class of (v, cs) =>
-          (Logic.legacy_varifyT o subst_clsvar v (TFree (v, [class])))
-          ((the o AList.lookup (op =) cs) c)
-(*     val _ = writeln "DECL"  *)
-(*     val _ = (writeln o Display.raw_string_of_typ) typ_decl;  *)
-(*     val _ = writeln "CTXT"  *)
-(*     val _ = (writeln o Display.raw_string_of_typ) typ_ctxt;  *)
-(*     val _ = writeln "(0)"  *)
-  in
-    instantiations_of thy (typ_decl, typ_ctxt)
-    |> the
-(*     |> tap (fn _ => writeln "(1)")  *)
-    |> map (fn ((_, sort), ty) => sortlookup thy (sort, ty))
-(*     |> tap (fn _ => writeln "(2)")  *)
-    |> filter_out null
-(*     |> tap (fn _ => writeln "(3)")  *)
+      (Type.no_tvars typ_ctxt, operational_sort_of thy sort_decl)
   end;
 
 
@@ -675,32 +616,23 @@
 
 in
 
-val (classK, instanceK) = ("class", "instance")
+val (classK, instanceK, print_classesK) = ("class", "instance", "print_classes")
 
-fun wrap_add_instance_sort ((class, sort), use_interp) thy =
-  if use_interp
-    andalso forall (is_some o lookup_class_data thy) (Sign.read_sort thy sort)
-  then
-    instance_sort (class, sort) thy
-  else
-    axclass_instance_sort (class, sort) thy;
+fun wrap_add_instance_sort (class, sort) thy =
+  (if forall (is_some o lookup_class_data thy) (Sign.read_sort thy sort)
+  then instance_sort else axclass_instance_sort) (class, sort) thy;
 
-val parse_inst =
+val class_subP = P.name -- Scan.repeat (P.$$$ "+" |-- P.name) >> (op ::);
+val class_bodyP = P.!!! (Scan.repeat1 P.context_element);
+
+val inst =
   (Scan.optional (P.$$$ "(" |-- P.!!! (P.list1 P.sort --| P.$$$ ")")) [] -- P.xname --| P.$$$ "::" -- P.sort)
     >> (fn ((asorts, tyco), sort) => ((tyco, asorts), sort))
   || (P.xname --| P.$$$ "::" -- P.!!! P.arity)
     >> (fn (tyco, (asorts, sort)) => ((tyco, asorts), sort));
 
-val locale_val =
-  (P.locale_expr --
-    Scan.optional (P.$$$ "+" |-- P.!!! (Scan.repeat1 P.context_element)) [] ||
-  Scan.repeat1 P.context_element >> pair Locale.empty);
-
-val class_subP = P.name -- Scan.repeat (P.$$$ "+" |-- P.name) >> (op ::);
-val class_bodyP = P.!!! (Scan.repeat1 P.context_element);
-
 val classP =
-  OuterSyntax.command classK "operational type classes" K.thy_decl (
+  OuterSyntax.command classK "define operational type class" K.thy_decl (
     P.name --| P.$$$ "="
     -- (
       class_subP --| P.$$$ "+" -- class_bodyP
@@ -711,13 +643,18 @@
 
 val instanceP =
   OuterSyntax.command instanceK "prove type arity or subclass relation" K.thy_goal ((
-      P.xname -- ((P.$$$ "\\<subseteq>" || P.$$$ "<") |-- P.!!! P.xname) -- P.opt_keyword "open" >> wrap_add_instance_sort
-      || P.opt_thm_name ":" -- (P.and_list1 parse_inst -- Scan.repeat (P.opt_thm_name ":" -- P.prop))
+      P.xname -- ((P.$$$ "\\<subseteq>" || P.$$$ "<") |-- P.!!! P.xname) >> wrap_add_instance_sort
+      || P.opt_thm_name ":" -- (P.and_list1 inst -- Scan.repeat (P.opt_thm_name ":" -- P.prop))
            >> (fn (("", []), ([((tyco, asorts), sort)], [])) => axclass_instance_arity I [(tyco, asorts, sort)]
                 | (natts, (arities, defs)) => instance_arity arities natts defs)
     ) >> (Toplevel.print oo Toplevel.theory_to_proof));
 
-val _ = OuterSyntax.add_parsers [classP, instanceP];
+val print_classesP =
+  OuterSyntax.improper_command print_classesK "print classes of this theory" K.diag
+    (Scan.succeed (Toplevel.no_timing o Toplevel.unknown_theory
+      o Toplevel.keep (print_classes o Toplevel.theory_of)));
+
+val _ = OuterSyntax.add_parsers [classP, instanceP, print_classesP];
 
 end; (* local *)