src/Pure/axclass.ML
changeset 36327 c0415cb24a10
parent 36326 85d026788fce
child 36328 4d9deabf6474
--- a/src/Pure/axclass.ML	Sun Apr 25 19:44:47 2010 +0200
+++ b/src/Pure/axclass.ML	Sun Apr 25 21:02:36 2010 +0200
@@ -13,8 +13,8 @@
   val add_arity: thm -> theory -> theory
   val prove_classrel: class * class -> tactic -> theory -> theory
   val prove_arity: string * sort list * sort -> tactic -> theory -> theory
-  val get_info: theory -> class ->
-    {def: thm, intro: thm, axioms: thm list, params: (string * typ) list}
+  type info = {def: thm, intro: thm, axioms: thm list, params: (string * typ) list}
+  val get_info: theory -> class -> info
   val class_intros: theory -> thm list
   val class_of_param: theory -> string -> class option
   val cert_classrel: theory -> class * class -> class * class
@@ -60,17 +60,17 @@
       fold_rev (fn q => if member (op =) ps q then I else add_param pp q) qs ps;
 
 
-(* axclasses *)
+(* axclass info *)
 
-datatype axclass = AxClass of
+type info =
  {def: thm,
   intro: thm,
   axioms: thm list,
   params: (string * typ) list};
 
-type axclasses = axclass Symtab.table * param list;
+type axclasses = info Symtab.table * param list;
 
-fun make_axclass ((def, intro, axioms), params) = AxClass
+fun make_axclass ((def, intro, axioms), params): info =
   {def = def, intro = intro, axioms = axioms, params = params};
 
 fun merge_axclasses pp ((tab1, params1), (tab2, params2)) : axclasses =
@@ -106,7 +106,7 @@
 
 (* setup data *)
 
-structure AxClassData = Theory_Data_PP
+structure Data = Theory_Data_PP
 (
   type T = axclasses * ((instances * inst_params) * (class * class) list);
   val empty = ((Symtab.empty, []), (((Symreltab.empty, Symtab.empty), (Symtab.empty, Symtab.empty)), []));
@@ -114,10 +114,11 @@
   fun merge pp ((axclasses1, ((instances1, inst_params1), diff_merge_classrels1)),
     (axclasses2, ((instances2, inst_params2), diff_merge_classrels2))) =
     let
-      val (classrels1, classrels2) = pairself (Symreltab.keys o fst) (instances1, instances2)
-      val diff_merge_classrels = subtract (op =) classrels1 classrels2
-        @ subtract (op =) classrels2 classrels1
-        @ diff_merge_classrels1 @ diff_merge_classrels2
+      val (classrels1, classrels2) = pairself (Symreltab.keys o fst) (instances1, instances2);
+      val diff_merge_classrels =
+        subtract (op =) classrels1 classrels2 @
+        subtract (op =) classrels2 classrels1 @
+        diff_merge_classrels1 @ diff_merge_classrels2;
     in
       (merge_axclasses pp (axclasses1, axclasses2),
         ((merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)),
@@ -128,29 +129,23 @@
 
 (* maintain axclasses *)
 
-val get_axclasses = #1 o AxClassData.get;
-val map_axclasses = AxClassData.map o apfst;
-
-val lookup_def = Symtab.lookup o #1 o get_axclasses;
+val get_axclasses = #1 o Data.get;
+val map_axclasses = Data.map o apfst;
 
 fun get_info thy c =
-  (case lookup_def thy c of
-    SOME (AxClass info) => info
+  (case Symtab.lookup (#1 (get_axclasses thy)) c of
+    SOME info => info
   | NONE => error ("No such axclass: " ^ quote c));
 
 fun class_intros thy =
   let
-    fun add_intro c =
-      (case lookup_def thy c of SOME (AxClass {intro, ...}) => cons intro | _ => I);
+    fun add_intro c = (case try (get_info thy) c of SOME {intro, ...} => cons intro | _ => I);
     val classes = Sign.all_classes thy;
   in map (Thm.class_triv thy) classes @ fold add_intro classes [] end;
 
-
-fun get_params thy pred =
+fun all_params_of thy S =
   let val params = #2 (get_axclasses thy);
-  in fold (fn (x, c) => if pred c then cons x else I) params [] end;
-
-fun all_params_of thy S = get_params thy (fn c => Sign.subsort thy (S, [c]));
+  in fold (fn (x, c) => if Sign.subsort thy (S, [c]) then cons x else I) params [] end;
 
 fun class_of_param thy = AList.lookup (op =) (#2 (get_axclasses thy));
 
@@ -159,11 +154,11 @@
 
 fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
 
-val get_instances = #1 o #1 o #2 o AxClassData.get;
-val map_instances = AxClassData.map o apsnd o apfst o apfst;
+val get_instances = #1 o #1 o #2 o Data.get;
+val map_instances = Data.map o apsnd o apfst o apfst;
 
-val get_diff_merge_classrels = #2 o #2 o AxClassData.get;
-val clear_diff_merge_classrels = AxClassData.map (apsnd (apsnd (K [])));
+val get_diff_merge_classrels = #2 o #2 o Data.get;
+val clear_diff_merge_classrels = Data.map (apsnd (apsnd (K [])));
 
 
 fun the_classrel thy (c1, c2) =
@@ -177,26 +172,29 @@
 
 fun put_trancl_classrel ((c1, c2), th) thy =
   let
-    val classrels = fst (get_instances thy)
-    val alg = Sign.classes_of thy
-    val {classes, ...} = alg |> Sorts.rep_algebra
+    val cert = Thm.cterm_of thy;
+    val certT = Thm.ctyp_of thy;
+
+    val classrels = fst (get_instances thy);
+    val classes = #classes (Sorts.rep_algebra (Sign.classes_of thy));
 
     fun reflcl_classrel (c1', c2') =
-      if c1' = c2' then Thm.trivial (Logic.mk_of_class (TVar(("'a",0),[]), c1') |> cterm_of thy)
-      else the_classrel_thm thy (c1', c2')
+      if c1' = c2'
+      then Thm.trivial (cert (Logic.mk_of_class (TVar ((Name.aT, 0), []), c1')))
+      else the_classrel_thm thy (c1', c2');
     fun gen_classrel (c1_pred, c2_succ) =
       let
         val th' = ((reflcl_classrel (c1_pred, c1) RS th) RS reflcl_classrel (c2, c2_succ))
-          |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [])))] []
-          |> Thm.close_derivation
-        val prf' = th' |> Thm.proof_of
-      in ((c1_pred, c2_succ), (th',prf')) end
+          |> Drule.instantiate' [SOME (certT (TVar ((Name.aT, 0), [])))] []
+          |> Thm.close_derivation;
+        val prf' = th' |> Thm.proof_of;
+      in ((c1_pred, c2_succ), (th', prf')) end;
 
-    val new_classrels = Library.map_product pair
-        (c1 :: Graph.imm_preds classes c1) (c2 :: Graph.imm_succs classes c2)
+    val new_classrels =
+      Library.map_product pair (c1 :: Graph.imm_preds classes c1) (c2 :: Graph.imm_succs classes c2)
       |> filter_out (Symreltab.defined classrels)
-      |> map gen_classrel
-    val needed = length new_classrels > 0
+      |> map gen_classrel;
+    val needed = not (null new_classrels);
   in
     (needed,
      if needed then
@@ -207,13 +205,13 @@
 
 fun complete_classrels thy =
   let
-    val diff_merge_classrels = get_diff_merge_classrels thy
-    val classrels = fst (get_instances thy)
+    val diff_merge_classrels = get_diff_merge_classrels thy;
+    val classrels = fst (get_instances thy);
     val (needed, thy') = (false, thy) |>
       fold (fn c12 => fn (needed, thy) =>
           put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> fst) thy
           |>> (fn b => needed orelse b))
-        diff_merge_classrels
+        diff_merge_classrels;
   in
     if null diff_merge_classrels then NONE
     else thy' |> clear_diff_merge_classrels |> SOME
@@ -246,9 +244,9 @@
       let
         val th1 = (th RS the_classrel_thm thy (c, c1))
           |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names_and_Ss) []
-          |> Thm.close_derivation
-        val prf1 = Thm.proof_of th1
-      in (((th1,thy_name), prf1), c1) end)
+          |> Thm.close_derivation;
+        val prf1 = Thm.proof_of th1;
+      in (((th1, thy_name), prf1), c1) end);
     val arities' = fold (fn (th_thy_prf1, c1) => Symtab.cons_list (t, ((c1, Ss), th_thy_prf1)))
       completions arities;
   in (null completions, arities') end;
@@ -281,24 +279,23 @@
 
 (* maintain instance parameters *)
 
-val get_inst_params = #2 o #1 o #2 o AxClassData.get;
-val map_inst_params = AxClassData.map o apsnd o apfst o apsnd;
+val get_inst_params = #2 o #1 o #2 o Data.get;
+val map_inst_params = Data.map o apsnd o apfst o apsnd;
 
 fun get_inst_param thy (c, tyco) =
-  case Symtab.lookup ((the_default Symtab.empty o Symtab.lookup (fst (get_inst_params thy))) c) tyco
-   of SOME c' => c'
-    | NONE => error ("No instance parameter for constant " ^ quote c
-        ^ " on type constructor " ^ quote tyco);
+  (case Symtab.lookup (the_default Symtab.empty (Symtab.lookup (#1 (get_inst_params thy)) c)) tyco of
+    SOME c' => c'
+  | NONE => error ("No instance parameter for constant " ^ quote c ^ " on type " ^ quote tyco));
 
-fun add_inst_param (c, tyco) inst = (map_inst_params o apfst
-      o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
+fun add_inst_param (c, tyco) inst =
+  (map_inst_params o apfst o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
   #> (map_inst_params o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
 
 val inst_of_param = Symtab.lookup o snd o get_inst_params;
 val param_of_inst = fst oo get_inst_param;
 
-fun inst_thms thy = (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst)
-  (get_inst_params thy) [];
+fun inst_thms thy =
+  (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst) (get_inst_params thy) [];
 
 fun get_inst_tyco consts = try (fst o dest_Type o the_single o Consts.typargs consts);
 
@@ -308,18 +305,20 @@
 fun unoverload_conv thy = MetaSimplifier.rewrite true (inst_thms thy);
 fun overload_conv thy = MetaSimplifier.rewrite true (map Thm.symmetric (inst_thms thy));
 
-fun lookup_inst_param consts params (c, T) = case get_inst_tyco consts (c, T)
- of SOME tyco => AList.lookup (op =) params (c, tyco)
-  | NONE => NONE;
+fun lookup_inst_param consts params (c, T) =
+  (case get_inst_tyco consts (c, T) of
+    SOME tyco => AList.lookup (op =) params (c, tyco)
+  | NONE => NONE);
 
 fun unoverload_const thy (c_ty as (c, _)) =
-  if is_some (class_of_param thy c)
-  then case get_inst_tyco (Sign.consts_of thy) c_ty
-   of SOME tyco => try (param_of_inst thy) (c, tyco) |> the_default c
-    | NONE => c
+  if is_some (class_of_param thy c) then
+    (case get_inst_tyco (Sign.consts_of thy) c_ty of
+      SOME tyco => try (param_of_inst thy) (c, tyco) |> the_default c
+    | NONE => c)
   else c;
 
 
+
 (** instances **)
 
 (* class relations *)
@@ -340,11 +339,8 @@
   cert_classrel thy (pairself (ProofContext.read_class (ProofContext.init thy)) raw_rel)
     handle TYPE (msg, _, _) => error msg;
 
-fun check_shyps_topped th errmsg =
-  let val {shyps, ...} = Thm.rep_thm th
-  in
-    forall null shyps orelse raise Fail errmsg
-  end;
+val shyps_topped = forall null o #shyps o Thm.rep_thm;
+
 
 (* declaration and definition of instances of overloaded constants *)
 
@@ -406,7 +402,7 @@
     val th' = th
       |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [c1])))] []
       |> Drule.unconstrainTs;
-    val _ = check_shyps_topped th' "add_classrel: nontop shyps after unconstrain"
+    val _ = shyps_topped th' orelse raise Fail "add_classrel: nontop shyps after unconstrain";
   in
     thy
     |> Sign.primitive_classrel (c1, c2)
@@ -430,7 +426,7 @@
     val th' = th
       |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names) []
       |> Drule.unconstrainTs;
-    val _ = check_shyps_topped th' "add_arity: nontop shyps after unconstrain"
+    val _ = shyps_topped th' orelse raise Fail "add_arity: nontop shyps after unconstrain";
   in
     thy
     |> fold (snd oo declare_overloaded) missing_params