added domain_error;
authorwenzelm
Tue, 02 May 2006 00:20:38 +0200
changeset 19529 690861f93d2b
parent 19528 7fbac32cded0
child 19530 486dd4b07188
added domain_error; added of_sort_derivation; tuned;
src/Pure/sorts.ML
--- a/src/Pure/sorts.ML	Tue May 02 00:20:37 2006 +0200
+++ b/src/Pure/sorts.ML	Tue May 02 00:20:38 2006 +0200
@@ -3,6 +3,14 @@
     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
 
 The order-sorted algebra of type classes.
+
+Classes denote (possibly empty) collections of types that are
+partially ordered by class inclusion. They are represented
+symbolically by strings.
+
+Sorts are intersections of finitely many classes. They are represented
+by lists of classes.  Normal forms of sorts are sorted lists of
+minimal classes (wrt. current class inclusion).
 *)
 
 signature SORTS =
@@ -26,37 +34,29 @@
   val sorts_le: classes -> sort list * sort list -> bool
   val inter_sort: classes -> sort * sort -> sort
   val norm_sort: classes -> sort -> sort
-  val of_sort: classes * arities -> typ * sort -> bool
-  exception DOMAIN of string * class
-  val mg_domain: classes * arities -> string -> sort -> sort list  (*exception DOMAIN*)
-  val witness_sorts: classes * arities -> string list ->
-    sort list -> sort list -> (typ * sort) list
   val add_arities: Pretty.pp -> classes -> string * (class * sort list) list -> arities -> arities
   val rebuild_arities: Pretty.pp -> classes -> arities -> arities
   val merge_arities: Pretty.pp -> classes -> arities * arities -> arities
   val add_class: Pretty.pp -> class * class list -> classes -> classes
   val add_classrel: Pretty.pp -> class * class -> classes -> classes
   val merge_classes: Pretty.pp -> classes * classes -> classes
+  exception DOMAIN of string * class
+  val domain_error: Pretty.pp -> string * class -> 'a
+  val mg_domain: classes * arities -> string -> sort -> sort list  (*exception DOMAIN*)
+  val of_sort: classes * arities -> typ * sort -> bool
+  val of_sort_derivation: Pretty.pp -> classes * arities ->
+    {classrel: 'a * class -> class -> 'a,
+     constructor: string -> ('a * class) list list -> class -> 'a,
+     variable: typ -> ('a * class) list} -> typ * sort -> 'a list
+  val witness_sorts: classes * arities -> string list ->
+    sort list -> sort list -> (typ * sort) list
 end;
 
 structure Sorts: SORTS =
 struct
 
 
-(** type classes and sorts **)
-
-(*
-  Classes denote (possibly empty) collections of types that are
-  partially ordered by class inclusion. They are represented
-  symbolically by strings.
-
-  Sorts are intersections of finitely many classes. They are
-  represented by lists of classes.  Normal forms of sorts are sorted
-  lists of minimal classes (wrt. current class inclusion).
-*)
-
-
-(* ordered lists of sorts *)
+(** ordered lists of sorts **)
 
 val eq_set = OrdList.eq_set Term.sort_ord;
 val op union = OrdList.union Term.sort_ord;
@@ -82,7 +82,8 @@
   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
 
 
-(* order-sorted algebra *)
+
+(** order-sorted algebra **)
 
 (*
   classes: graph representing class declarations together with proper
@@ -100,17 +101,14 @@
 type arities = (class * (class * sort list)) list Symtab.table;
 
 
-
-(** equality and inclusion **)
-
-(* classes *)
+(* class relations *)
 
 fun class_eq (_: classes) (c1, c2:class) = c1 = c2;
 val class_less: classes -> class * class -> bool = Graph.is_edge;
 fun class_le classes (c1, c2) = c1 = c2 orelse class_less classes (c1, c2);
 
 
-(* sorts *)
+(* sort relations *)
 
 fun sort_le classes (S1, S2) =
   forall (fn c2 => exists (fn c1 => class_le classes (c1, c2)) S1) S2;
@@ -122,18 +120,7 @@
   sort_le classes (S1, S2) andalso sort_le classes (S2, S1);
 
 
-(* normal forms of sorts *)
-
-fun minimal_class classes S c =
-  not (exists (fn c' => class_less classes (c', c)) S);
-
-fun norm_sort _ [] = []
-  | norm_sort _ (S as [_]) = S
-  | norm_sort classes S = sort_distinct string_ord (filter (minimal_class classes S) S);
-
-
-
-(** intersection -- preserving minimality **)
+(* intersection *)
 
 fun inter_class classes c S =
   let
@@ -148,103 +135,17 @@
   sort_strings (fold (inter_class classes) S1 S2);
 
 
-
-(** sorts of types **)
-
-(* mg_domain *)
-
-exception DOMAIN of string * class;
+(* normal forms *)
 
-fun mg_domain (classes, arities) a S =
-  let
-    fun dom c =
-      (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
-        NONE => raise DOMAIN (a, c)
-      | SOME (_, Ss) => Ss);
-    fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
-  in
-    (case S of
-      [] => sys_error "mg_domain"  (*don't know number of args!*)
-    | c :: cs => fold dom_inter cs (dom c))
-  end;
-
-
-(* of_sort *)
-
-fun of_sort (classes, arities) =
-  let
-    fun ofS (_, []) = true
-      | ofS (TFree (_, S), S') = sort_le classes (S, S')
-      | ofS (TVar (_, S), S') = sort_le classes (S, S')
-      | ofS (Type (a, Ts), S) =
-          let val Ss = mg_domain (classes, arities) a S in
-            ListPair.all ofS (Ts, Ss)
-          end handle DOMAIN _ => false;
-  in ofS end;
+fun norm_sort _ [] = []
+  | norm_sort _ (S as [_]) = S
+  | norm_sort classes S =
+      filter (fn c => not (exists (fn c' => class_less classes (c', c)) S)) S
+      |> sort_distinct string_ord;
 
 
 
-(** witness_sorts **)
-
-local
-
-fun witness_aux (classes, arities) log_types hyps sorts =
-  let
-    val top_witn = (propT, []);
-    fun le S1 S2 = sort_le classes (S1, S2);
-    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
-    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
-    fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle DOMAIN _ => NONE;
-
-    fun witn_sort _ (solved_failed, []) = (solved_failed, SOME top_witn)
-      | witn_sort path ((solved, failed), S) =
-          if exists (le S) failed then ((solved, failed), NONE)
-          else
-            (case get_first (get_solved S) solved of
-              SOME w => ((solved, failed), SOME w)
-            | NONE =>
-                (case get_first (get_hyp S) hyps of
-                  SOME w => ((w :: solved, failed), SOME w)
-                | NONE => witn_types path log_types ((solved, failed), S)))
-
-    and witn_sorts path x = foldl_map (witn_sort path) x
-
-    and witn_types _ [] ((solved, failed), S) = ((solved, S :: failed), NONE)
-      | witn_types path (t :: ts) (solved_failed, S) =
-          (case mg_dom t S of
-            SOME SS =>
-              (*do not descend into stronger args (achieving termination)*)
-              if exists (fn D => le D S orelse exists (le D) path) SS then
-                witn_types path ts (solved_failed, S)
-              else
-                let val ((solved', failed'), ws) = witn_sorts (S :: path) (solved_failed, SS) in
-                  if forall is_some ws then
-                    let val w = (Type (t, map (#1 o the) ws), S)
-                    in ((w :: solved', failed'), SOME w) end
-                  else witn_types path ts ((solved', failed'), S)
-                end
-          | NONE => witn_types path ts (solved_failed, S));
-
-  in witn_sorts [] (([], []), sorts) end;
-
-fun str_of_sort [c] = c
-  | str_of_sort cs = enclose "{" "}" (commas cs);
-
-in
-
-fun witness_sorts (classes, arities) log_types hyps sorts =
-  let
-    fun double_check_result NONE = NONE
-      | double_check_result (SOME (T, S)) =
-          if of_sort (classes, arities) (T, S) then SOME (T, S)
-          else sys_error ("Sorts.witness_sorts: bad witness for sort " ^ str_of_sort S);
-  in map_filter double_check_result (#2 (witness_aux (classes, arities) log_types hyps sorts)) end;
-
-end;
-
-
-
-(** build sort algebras **)
+(** build algebras **)
 
 (* classes *)
 
@@ -326,8 +227,8 @@
     |> fold_rev (fold_rev (insert pp classes t)) (map (complete classes) ars)
   in Symtab.update (t, ars') arities end;
 
-fun add_arities_table pp classes = Symtab.fold (fn (t, ars) =>
-  add_arities pp classes (t, map (apsnd (map (norm_sort classes)) o snd) ars));
+fun add_arities_table pp classes =
+  Symtab.fold (fn (t, ars) => add_arities pp classes (t, map snd ars));
 
 fun rebuild_arities pp classes arities =
   Symtab.empty
@@ -340,4 +241,129 @@
 
 end;
 
+
+
+(** sorts of types **)
+
+(* mg_domain *)
+
+exception DOMAIN of string * class;
+
+fun domain_error pp (a, c) =
+  error ("No way to get " ^ Pretty.string_of_arity pp (a, [], [c]));
+
+fun mg_domain (classes, arities) a S =
+  let
+    fun dom c =
+      (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
+        NONE => raise DOMAIN (a, c)
+      | SOME (_, Ss) => Ss);
+    fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
+  in
+    (case S of
+      [] => raise Fail "Unknown domain of empty intersection"
+    | c :: cs => fold dom_inter cs (dom c))
+  end;
+
+
+(* of_sort *)
+
+fun of_sort (classes, arities) =
+  let
+    fun ofS (_, []) = true
+      | ofS (TFree (_, S), S') = sort_le classes (S, S')
+      | ofS (TVar (_, S), S') = sort_le classes (S, S')
+      | ofS (Type (a, Ts), S) =
+          let val Ss = mg_domain (classes, arities) a S in
+            ListPair.all ofS (Ts, Ss)
+          end handle DOMAIN _ => false;
+  in ofS end;
+
+
+(* of_sort_derivation *)
+
+fun of_sort_derivation pp (classes, arities) {classrel, constructor, variable} =
+  let
+    fun weaken (x, c1) c2 = if c1 = c2 then x else classrel (x, c1) c2;
+    fun weakens S1 S2 = S2 |> map (fn c2 =>
+      (case S1 |> find_first (fn (_, c1) => class_le classes (c1, c2)) of
+        SOME d1 => weaken d1 c2
+      | NONE => error ("Cannot derive subsort relation " ^
+          Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
+
+    fun derive _ [] = []
+      | derive (Type (a, Ts)) S =
+          let
+            val Ss = mg_domain (classes, arities) a S
+              handle DOMAIN d => domain_error pp d;
+            val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
+          in
+            S |> map (fn c =>
+              let
+                val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
+                val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
+              in weaken (constructor a dom' c0, c0) c end)
+          end
+      | derive T S = weakens (variable T) S;
+  in uncurry derive end;
+
+
+(* witness_sorts *)
+
+local
+
+fun witness_aux (classes, arities) log_types hyps sorts =
+  let
+    val top_witn = (propT, []);
+    fun le S1 S2 = sort_le classes (S1, S2);
+    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
+    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
+    fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle DOMAIN _ => NONE;
+
+    fun witn_sort _ (solved_failed, []) = (solved_failed, SOME top_witn)
+      | witn_sort path ((solved, failed), S) =
+          if exists (le S) failed then ((solved, failed), NONE)
+          else
+            (case get_first (get_solved S) solved of
+              SOME w => ((solved, failed), SOME w)
+            | NONE =>
+                (case get_first (get_hyp S) hyps of
+                  SOME w => ((w :: solved, failed), SOME w)
+                | NONE => witn_types path log_types ((solved, failed), S)))
+
+    and witn_sorts path x = foldl_map (witn_sort path) x
+
+    and witn_types _ [] ((solved, failed), S) = ((solved, S :: failed), NONE)
+      | witn_types path (t :: ts) (solved_failed, S) =
+          (case mg_dom t S of
+            SOME SS =>
+              (*do not descend into stronger args (achieving termination)*)
+              if exists (fn D => le D S orelse exists (le D) path) SS then
+                witn_types path ts (solved_failed, S)
+              else
+                let val ((solved', failed'), ws) = witn_sorts (S :: path) (solved_failed, SS) in
+                  if forall is_some ws then
+                    let val w = (Type (t, map (#1 o the) ws), S)
+                    in ((w :: solved', failed'), SOME w) end
+                  else witn_types path ts ((solved', failed'), S)
+                end
+          | NONE => witn_types path ts (solved_failed, S));
+
+  in witn_sorts [] (([], []), sorts) end;
+
+fun str_of_sort [c] = c
+  | str_of_sort cs = enclose "{" "}" (commas cs);
+
+in
+
+fun witness_sorts (classes, arities) log_types hyps sorts =
+  let
+    fun double_check_result NONE = NONE
+      | double_check_result (SOME (T, S)) =
+          if of_sort (classes, arities) (T, S) then SOME (T, S)
+          else sys_error ("Sorts.witness_sorts: bad witness for sort " ^ str_of_sort S);
+  in map_filter double_check_result (#2 (witness_aux (classes, arities) log_types hyps sorts)) end;
+
 end;
+
+end;