src/Pure/axclass.ML
changeset 19503 10921826b160
parent 19482 9f11af8f7ef9
child 19511 b4bd790f9373
--- a/src/Pure/axclass.ML	Sat Apr 29 23:16:43 2006 +0200
+++ b/src/Pure/axclass.ML	Sat Apr 29 23:16:45 2006 +0200
@@ -9,10 +9,6 @@
 sig
   val print_axclasses: theory -> unit
   val get_info: theory -> class -> {def: thm, intro: thm, axioms: thm list}
-  val get_instances: theory ->
-   {classes: unit Graph.T,
-    classrel: ((class * class) * thm) list,
-    arities: ((string * sort list * class) * thm) list}
   val class_intros: theory -> thm list
   val params_of: theory -> class -> string list
   val all_params_of: theory -> sort -> string list
@@ -22,6 +18,7 @@
   val add_arity: thm -> theory -> theory
   val prove_classrel: class * class -> tactic -> theory -> theory
   val prove_arity: string * sort list * sort -> tactic -> theory -> theory
+  val of_sort: theory -> typ * sort -> thm list
   val add_axclass: bstring * xstring list -> string list ->
     ((bstring * Attrib.src list) * string list) list -> theory -> class * theory
   val add_axclass_i: bstring * class list -> string list ->
@@ -31,7 +28,6 @@
 structure AxClass: AX_CLASS =
 struct
 
-
 (** theory data **)
 
 (* class parameters (canonical order) *)
@@ -74,21 +70,23 @@
 datatype instances = Instances of
  {classes: unit Graph.T,                 (*raw relation -- no closure!*)
   classrel: ((class * class) * thm) list,
-  arities: ((string * sort list * class) * thm) list};
+  arities: ((class * sort list) * thm) list Symtab.table,
+  types: (class * thm) list Typtab.table};
 
-fun make_instances (classes, classrel, arities) =
-  Instances {classes = classes, classrel = classrel, arities = arities};
+fun make_instances (classes, classrel, arities, types) =
+  Instances {classes = classes, classrel = classrel, arities = arities, types = types};
 
-fun map_instances f (Instances {classes, classrel, arities}) =
-  make_instances (f (classes, classrel, arities));
+fun map_instances f (Instances {classes, classrel, arities, types}) =
+  make_instances (f (classes, classrel, arities, types));
 
 fun merge_instances
-   (Instances {classes = classes1, classrel = classrel1, arities = arities1},
-    Instances {classes = classes2, classrel = classrel2, arities = arities2}) =
+   (Instances {classes = classes1, classrel = classrel1, arities = arities1, types = types1},
+    Instances {classes = classes2, classrel = classrel2, arities = arities2, types = types2}) =
   make_instances
    (Graph.merge (K true) (classes1, classes2),
     merge (eq_fst op =) (classrel1, classrel2),
-    merge (eq_fst op =) (arities1, arities2));
+    Symtab.join (K (merge (eq_fst op =))) (arities1, arities2),
+    Typtab.join (K (merge (eq_fst op =))) (types1, types2));
 
 
 (* data *)
@@ -96,19 +94,20 @@
 structure AxClassData = TheoryDataFun
 (struct
   val name = "Pure/axclass";
-  type T = axclasses * instances;
-  val empty : T = ((Symtab.empty, []), make_instances (Graph.empty, [], []));
-  val copy = I;
-  val extend = I;
-  fun merge pp ((axclasses1, instances1), (axclasses2, instances2)) =
-    (merge_axclasses pp (axclasses1, axclasses2), merge_instances (instances1, instances2));
+  type T = axclasses * instances ref;
+  val empty : T =
+    ((Symtab.empty, []), ref (make_instances (Graph.empty, [], Symtab.empty, Typtab.empty)));
+  fun copy (axclasses, ref instances) : T = (axclasses, ref instances);
+  val extend = copy;
+  fun merge pp ((axclasses1, ref instances1), (axclasses2, ref instances2)) =
+    (merge_axclasses pp (axclasses1, axclasses2), ref (merge_instances (instances1, instances2)));
   fun print _ _ = ();
 end);
 
 val _ = Context.add_setup AxClassData.init;
 
 
-(* lookup classes *)
+(* classes *)
 
 val lookup_info = Symtab.lookup o #1 o #1 o AxClassData.get;
 
@@ -125,7 +124,7 @@
   in map (Thm.class_triv thy) classes @ fold add_intro classes [] end;
 
 
-(* lookup parameters *)
+(* parameters *)
 
 fun get_params thy pred =
   let val params = #2 (#1 (AxClassData.get thy))
@@ -135,6 +134,30 @@
 fun all_params_of thy S = get_params thy (fn c => Sign.subsort thy (S, [c]));
 
 
+(* instances *)
+
+val get_instances = AxClassData.get #> (fn (_, ref (Instances insts)) => insts);
+
+fun store_instance f thy (x, th) =
+  let
+    val th' = Drule.standard' th;
+    val _ = change (#2 (AxClassData.get thy)) (map_instances (f (x, th')));
+  in th' end;
+
+val store_classrel = store_instance (fn ((c1, c2), th) => fn (classes, classrel, arities, types) =>
+  (classes
+    |> Graph.default_node (c1, ())
+    |> Graph.default_node (c2, ())
+    |> Graph.add_edge (c1, c2),
+    insert (eq_fst op =) ((c1, c2), th) classrel, arities, types));
+
+val store_arity = store_instance (fn ((t, Ss, c), th) => fn (classes, classrel, arities, types) =>
+  (classes, classrel, arities |> Symtab.insert_list (eq_fst op =) (t, ((c, Ss), th)), types));
+
+val store_type = store_instance (fn ((T, c), th) => fn (classes, classrel, arities, types) =>
+  (classes, classrel, arities, types |> Typtab.insert_list (eq_fst op =) (T, (c, th))));
+
+
 (* print_axclasses *)
 
 fun print_axclasses thy =
@@ -154,10 +177,7 @@
 
 
 
-(** instances **)
-
-val get_instances = AxClassData.get #> (fn (_, Instances insts) => insts);
-
+(** instance proofs **)
 
 (* class relations *)
 
@@ -185,28 +205,18 @@
     val prop = Drule.plain_prop_of (Thm.transfer thy th);
     val rel = Logic.dest_classrel prop handle TERM _ => err ();
     val (c1, c2) = cert_classrel thy rel handle TYPE _ => err ();
-  in
-    thy
-    |> Theory.add_classrel_i [(c1, c2)]
-    |> AxClassData.map (apsnd (map_instances (fn (classes, classrel, arities) =>
-        (classes
-            |> Graph.default_node (c1, ())
-            |> Graph.default_node (c2, ())
-            |> Graph.add_edge (c1, c2),
-          ((c1, c2), th) :: classrel, arities))))
-  end;
+    val thy' = thy |> Theory.add_classrel_i [(c1, c2)];
+    val _ = store_classrel thy' ((c1, c2), Drule.unconstrainTs th);
+  in thy' end;
 
 fun add_arity th thy =
   let
     val prop = Drule.plain_prop_of (Thm.transfer thy th);
     val (t, Ss, c) = Logic.dest_arity prop handle TERM _ =>
       raise THM ("add_arity: malformed type arity", 0, [th]);
-  in
-    thy
-    |> Theory.add_arities_i [(t, Ss, [c])]
-    |> AxClassData.map (apsnd (map_instances (fn (classes, classrel, arities) =>
-      (classes, classrel, ((t, Ss, c), th) :: arities))))
-  end;
+    val thy' = thy |> Theory.add_arities_i [(t, Ss, [c])];
+    val _ = store_arity thy' ((t, Ss, c), Drule.unconstrainTs th);
+  in thy' end;
 
 
 (* tactical proofs *)
@@ -230,6 +240,83 @@
   in fold add_arity ths thy end;
 
 
+(* derived instances -- cached *)
+
+fun derive_classrel thy (c1, c2) =
+  let
+    val {classes, classrel, ...} = get_instances thy;
+    val lookup = AList.lookup (op =) classrel;
+    fun derive [c, c'] = the (lookup (c, c'))
+      | derive (c :: c' :: cs) = derive [c, c'] RS derive (c' :: cs);
+  in
+    (case lookup (c1, c2) of
+      SOME rule => rule
+    | NONE =>
+        (case Graph.find_paths classes (c1, c2) of
+          [] => error ("Cannot derive class relation " ^ Sign.string_of_classrel thy [c1, c2])
+        | path :: _ => store_classrel thy ((c1, c2), derive path)))
+  end;
+
+fun weaken_subclass thy (c1, th) c2 =
+  if c1 = c2 then th
+  else th RS derive_classrel thy (c1, c2);
+
+fun weaken_subsort thy S1 S2 = S2 |> map (fn c2 =>
+  (case S1 |> find_first (fn (c1, _) => Sign.subsort thy ([c1], [c2])) of
+    SOME c1 => weaken_subclass thy c1 c2
+  | NONE => error ("Cannot derive subsort relation " ^
+      Sign.string_of_sort thy (map #1 S1) ^ " < " ^ Sign.string_of_sort thy S2)));
+
+fun apply_arity thy t dom c =
+  let
+    val {arities, ...} = get_instances thy;
+    val subsort = Sign.subsort thy;
+    val Ss = map (map #1) dom;
+  in
+    (case Symtab.lookup_list arities t |> find_first (fn ((c', Ss'), _) =>
+        subsort ([c'], [c]) andalso ListPair.all subsort (Ss, Ss')) of
+      SOME ((c', Ss'), rule) =>
+        weaken_subclass thy (c', rule OF flat (map2 (weaken_subsort thy) dom Ss')) c
+    | NONE => error ("Cannot derive type arity " ^ Sign.string_of_arity thy (t, Ss, [c])))
+  end;
+
+fun derive_type thy hyps =
+  let
+    fun derive (Type (a, Ts)) S =
+          let val Ss = Sign.arity_sorts thy a S
+          in map (apply_arity thy a (map2 (fn T => fn S => S ~~ derive T S) Ts Ss)) S end
+      | derive (TFree (a, [])) S =
+          weaken_subsort thy (the_default [] (AList.lookup (op =) hyps a)) S
+      | derive T _ = error ("Illegal occurrence of type variable " ^
+          setmp show_sorts true (Sign.string_of_typ thy) T);
+  in derive end;
+
+fun of_sort thy (typ, sort) =
+  let
+    fun lookup () = AList.lookup (op =) (Typtab.lookup_list (#types (get_instances thy)) typ);
+    val sort' = filter (is_none o lookup ()) sort;
+    val _ = conditional (not (null sort')) (fn () =>
+      let
+        val vars = Term.fold_atyps (insert (op =)) typ [];
+        val renaming =
+          map2 (fn T => fn a => (T, (a, case T of TFree (_, S) => S | TVar (_, S) => S)))
+            vars (Term.invent_names [] "'a" (length vars));
+        val typ' = typ |> Term.map_atyps
+          (fn T => TFree (#1 (the (AList.lookup (op =) renaming T)), []));
+
+        val hyps = renaming |> map (fn (_, (a, S)) => (a, S ~~ (S |> map (fn c =>
+          Thm.assume (Thm.cterm_of thy (Logic.mk_inclass (TFree (a, []), c)))))));
+        val inst = renaming |> map (fn (T, (a, S)) =>
+          pairself (Thm.ctyp_of thy) (TVar ((a, 0), S), T));
+
+        val ths =
+          derive_type thy hyps typ' sort'
+          |> map (Thm.instantiate (inst, []));
+        val _ = map (store_type thy) (map (pair typ) sort' ~~ ths);
+      in () end);
+  in map (the o lookup ()) sort end;
+
+
 
 (** axclass definitions **)
 
@@ -286,8 +373,9 @@
       def_thy
       |> PureThy.note_thmss_qualified "" bconst
         [((introN, []), [([Drule.standard raw_intro], [])]),
-         ((axiomsN, []), [(map (fn th => Drule.standard (class_triv RS th)) raw_axioms, [])])]
-      ||> fold (fn th => add_classrel (Drule.standard' (class_triv RS th))) raw_classrel;
+         ((axiomsN, []), [(map (fn th => Drule.standard (class_triv RS th)) raw_axioms, [])])];
+    val _ = map (store_classrel facts_thy)
+      (map (pair class) super ~~ map Drule.standard raw_classrel);
 
 
     (* result *)
@@ -297,9 +385,9 @@
       |> Sign.add_path bconst
       |> PureThy.note_thmss_i "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> snd
       |> Sign.restore_naming facts_thy
-      |> AxClassData.map (apfst (fn (is, ps) =>
-        (Symtab.update (class, make_axclass (def, intro, axioms)) is,
-          fold (fn x => add_param pp (x, class)) params ps)));
+      |> AxClassData.map (apfst (fn (axclasses, parameters) =>
+        (Symtab.update (class, make_axclass (def, intro, axioms)) axclasses,
+          fold (fn x => add_param pp (x, class)) params parameters)));
 
   in (class, result_thy) end;