--- a/src/Pure/sorts.ML Fri May 05 21:59:45 2006 +0200
+++ b/src/Pure/sorts.ML Fri May 05 21:59:46 2006 +0200
@@ -40,14 +40,15 @@
val add_class: Pretty.pp -> class * class list -> classes -> classes
val add_classrel: Pretty.pp -> class * class -> classes -> classes
val merge_classes: Pretty.pp -> classes * classes -> classes
- exception DOMAIN of string * class
- val domain_error: Pretty.pp -> string * class -> 'a
- val mg_domain: classes * arities -> string -> sort -> sort list (*exception DOMAIN*)
+ type class_error
+ val class_error: Pretty.pp -> class_error -> 'a
+ exception CLASS_ERROR of class_error
+ val mg_domain: classes * arities -> string -> sort -> sort list (*exception CLASS_ERROR*)
val of_sort: classes * arities -> typ * sort -> bool
val of_sort_derivation: Pretty.pp -> classes * arities ->
{classrel: 'a * class -> class -> 'a,
constructor: string -> ('a * class) list list -> class -> 'a,
- variable: typ -> ('a * class) list} -> typ * sort -> 'a list
+ variable: typ -> ('a * class) list} -> typ * sort -> 'a list (*exception CLASS_ERROR*)
val witness_sorts: classes * arities -> string list ->
sort list -> sort list -> (typ * sort) list
end;
@@ -245,18 +246,25 @@
(** sorts of types **)
-(* mg_domain *)
+(* errors *)
+
+datatype class_error = NoClassrel of class * class | NoArity of string * class;
-exception DOMAIN of string * class;
+fun class_error pp (NoClassrel (c1, c2)) =
+ error ("No class relation " ^ Pretty.string_of_classrel pp [c1, c2])
+ | class_error pp (NoArity (a, c)) =
+ error ("No type arity " ^ Pretty.string_of_arity pp (a, [], [c]));
-fun domain_error pp (a, c) =
- error ("No way to get " ^ Pretty.string_of_arity pp (a, [], [c]));
+exception CLASS_ERROR of class_error;
+
+
+(* mg_domain *)
fun mg_domain (classes, arities) a S =
let
fun dom c =
(case AList.lookup (op =) (Symtab.lookup_list arities a) c of
- NONE => raise DOMAIN (a, c)
+ NONE => raise CLASS_ERROR (NoArity (a, c))
| SOME (_, Ss) => Ss);
fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
in
@@ -276,7 +284,7 @@
| ofS (Type (a, Ts), S) =
let val Ss = mg_domain (classes, arities) a S in
ListPair.all ofS (Ts, Ss)
- end handle DOMAIN _ => false;
+ end handle CLASS_ERROR _ => false;
in ofS end;
@@ -284,7 +292,13 @@
fun of_sort_derivation pp (classes, arities) {classrel, constructor, variable} =
let
- fun weaken (x, c1) c2 = if c1 = c2 then x else classrel (x, c1) c2;
+ fun weaken_path (x, c1 :: c2 :: cs) = weaken_path (classrel (x, c1) c2, c2 :: cs)
+ | weaken_path (x, _) = x;
+ fun weaken (x, c1) c2 =
+ (case Graph.irreducible_paths classes (c1, c2) of
+ [] => raise CLASS_ERROR (NoClassrel (c1, c2))
+ | cs :: _ => weaken_path (x, cs));
+
fun weakens S1 S2 = S2 |> map (fn c2 =>
(case S1 |> find_first (fn (_, c1) => class_le classes (c1, c2)) of
SOME d1 => weaken d1 c2
@@ -294,8 +308,7 @@
fun derive _ [] = []
| derive (Type (a, Ts)) S =
let
- val Ss = mg_domain (classes, arities) a S
- handle DOMAIN d => domain_error pp d;
+ val Ss = mg_domain (classes, arities) a S;
val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
in
S |> map (fn c =>
@@ -310,60 +323,46 @@
(* witness_sorts *)
-local
-
-fun witness_aux (classes, arities) log_types hyps sorts =
+fun witness_sorts (classes, arities) log_types hyps sorts =
let
- val top_witn = (propT, []);
fun le S1 S2 = sort_le classes (S1, S2);
fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
- fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle DOMAIN _ => NONE;
+ fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle CLASS_ERROR _ => NONE;
- fun witn_sort _ (solved_failed, []) = (solved_failed, SOME top_witn)
- | witn_sort path ((solved, failed), S) =
- if exists (le S) failed then ((solved, failed), NONE)
+ fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
+ | witn_sort path S (solved, failed) =
+ if exists (le S) failed then (NONE, (solved, failed))
else
(case get_first (get_solved S) solved of
- SOME w => ((solved, failed), SOME w)
+ SOME w => (SOME w, (solved, failed))
| NONE =>
(case get_first (get_hyp S) hyps of
- SOME w => ((w :: solved, failed), SOME w)
- | NONE => witn_types path log_types ((solved, failed), S)))
+ SOME w => (SOME w, (w :: solved, failed))
+ | NONE => witn_types path log_types S (solved, failed)))
- and witn_sorts path x = foldl_map (witn_sort path) x
+ and witn_sorts path x = fold_map (witn_sort path) x
- and witn_types _ [] ((solved, failed), S) = ((solved, S :: failed), NONE)
- | witn_types path (t :: ts) (solved_failed, S) =
+ and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
+ | witn_types path (t :: ts) S solved_failed =
(case mg_dom t S of
SOME SS =>
(*do not descend into stronger args (achieving termination)*)
if exists (fn D => le D S orelse exists (le D) path) SS then
- witn_types path ts (solved_failed, S)
+ witn_types path ts S solved_failed
else
- let val ((solved', failed'), ws) = witn_sorts (S :: path) (solved_failed, SS) in
+ let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
if forall is_some ws then
let val w = (Type (t, map (#1 o the) ws), S)
- in ((w :: solved', failed'), SOME w) end
- else witn_types path ts ((solved', failed'), S)
+ in (SOME w, (w :: solved', failed')) end
+ else witn_types path ts S (solved', failed')
end
- | NONE => witn_types path ts (solved_failed, S));
-
- in witn_sorts [] (([], []), sorts) end;
-
-fun str_of_sort [c] = c
- | str_of_sort cs = enclose "{" "}" (commas cs);
+ | NONE => witn_types path ts S solved_failed);
-in
+ fun double_check TS =
+ if of_sort (classes, arities) TS then TS
+ else sys_error "FIXME Bad sort witness";
-fun witness_sorts (classes, arities) log_types hyps sorts =
- let
- fun double_check_result NONE = NONE
- | double_check_result (SOME (T, S)) =
- if of_sort (classes, arities) (T, S) then SOME (T, S)
- else sys_error ("Sorts.witness_sorts: bad witness for sort " ^ str_of_sort S);
- in map_filter double_check_result (#2 (witness_aux (classes, arities) log_types hyps sorts)) end;
+ in map_filter (Option.map double_check) (#1 (witn_sorts [] sorts ([], []))) end;
end;
-
-end;