src/Pure/sorts.ML
author wenzelm
Tue May 16 13:01:28 2006 +0200 (2006-05-16)
changeset 19645 bbda28f2d379
parent 19584 606d6a73e6d9
child 19952 eaf2c25654d3
permissions -rw-r--r--
abstract interfaces for type algebra;
tuned;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
signature SORTS =
sig
  (*set operations on sort lists kept in Term.sort_ord order (via OrdList)*)
  val eq_set: sort list * sort list -> bool
  val union: sort list -> sort list -> sort list
  val subtract: sort list -> sort list -> sort list
  val remove_sort: sort -> sort list -> sort list
  val insert_sort: sort -> sort list -> sort list
  (*insert the sorts of all type variables (TFree, TVar) occurring in
    types / terms into an ordered sort list*)
  val insert_typ: typ -> sort list -> sort list
  val insert_typs: typ list -> sort list -> sort list
  val insert_term: term -> sort list -> sort list
  val insert_terms: term list -> sort list -> sort list
  (*order-sorted algebra: class graph together with type arities*)
  type algebra
  val rep_algebra: algebra ->
   {classes: stamp Graph.T,
    arities: (class * (class * sort list)) list Symtab.table}
  (*class graph queries; class_less tests for an edge of the class graph,
    class_le additionally accepts equal classes*)
  val classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  (*sort_le (S1, S2): every class c2 of S2 has some class c1 of S1 with
    class_le (c1, c2); sorts_le is the pointwise extension to lists*)
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  val inter_sort: algebra -> sort * sort -> sort
  (*check that classes are declared; certify_sort also normalizes the sort*)
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
  (*extend / combine algebras; duplicate classes, cycles in the class
    relation and conflicting arities are reported via error*)
  val add_class: Pretty.pp -> class * class list -> algebra -> algebra
  val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
  val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
  (*missing class relation or type arity; class_error reports it via error*)
  type class_error
  val class_error: Pretty.pp -> class_error -> 'a
  exception CLASS_ERROR of class_error
  (*argument sorts required for a type constructor to belong to a sort*)
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  (*membership test; false if a required type arity is missing*)
  val of_sort: algebra -> typ * sort -> bool
  (*evidence for membership, one result per class of the sort, built from
    the supplied classrel / constructor / variable functions*)
  val of_sort_derivation: Pretty.pp -> algebra ->
    {classrel: 'a * class -> class -> 'a,
     constructor: string -> ('a * class) list list -> class -> 'a,
     variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  (*witness types for those of the given sorts that can be witnessed, built
    from the given type constructor names and hypothesis sorts*)
  val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
end;
    58 
    59 structure Sorts: SORTS =
    60 struct
    61 
    62 
    63 (** ordered lists of sorts **)
    64 
    65 val eq_set = OrdList.eq_set Term.sort_ord;
    66 val op union = OrdList.union Term.sort_ord;
    67 val subtract = OrdList.subtract Term.sort_ord;
    68 
    69 val remove_sort = OrdList.remove Term.sort_ord;
    70 val insert_sort = OrdList.insert Term.sort_ord;
    71 
    72 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    73   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    74   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    75 and insert_typs [] Ss = Ss
    76   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    77 
    78 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    79   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    80   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    81   | insert_term (Bound _) Ss = Ss
    82   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    83   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    84 
    85 fun insert_terms [] Ss = Ss
    86   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    87 
    88 
    89 
    90 (** order-sorted algebra **)
    91 
    92 (*
    93   classes: graph representing class declarations together with proper
    94     subclass relation, which needs to be transitive and acyclic.
    95 
    96   arities: table of association lists of all type arities; (t, ars)
    97     means that type constructor t has the arities ars; an element
    98     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
    via c0 <= c.  "Coregularity" of the arities structure requires
    that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 with
    c1 <= c2, also Ss1 <= Ss2 holds.
   102 *)
   103 
   104 datatype algebra = Algebra of
   105  {classes: stamp Graph.T,
   106   arities: (class * (class * sort list)) list Symtab.table};
   107 
   108 fun rep_algebra (Algebra args) = args;
   109 
   110 val classes_of = #classes o rep_algebra;
   111 val arities_of = #arities o rep_algebra;
   112 
   113 fun make_algebra (classes, arities) =
   114   Algebra {classes = classes, arities = arities};
   115 
   116 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   117 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   118 
   119 
   120 (* classes *)
   121 
   122 val classes = Graph.keys o classes_of;
   123 val super_classes = Graph.imm_succs o classes_of;
   124 
   125 
   126 (* class relations *)
   127 
   128 val class_less = Graph.is_edge o classes_of;
   129 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   130 
   131 
   132 (* sort relations *)
   133 
   134 fun sort_le algebra (S1, S2) =
   135   forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   136 
   137 fun sorts_le algebra (Ss1, Ss2) =
   138   ListPair.all (sort_le algebra) (Ss1, Ss2);
   139 
   140 fun sort_eq algebra (S1, S2) =
   141   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   142 
   143 
   144 (* intersection *)
   145 
   146 fun inter_class algebra c S =
   147   let
   148     fun intr [] = [c]
   149       | intr (S' as c' :: c's) =
   150           if class_le algebra (c', c) then S'
   151           else if class_le algebra (c, c') then intr c's
   152           else c' :: intr c's
   153   in intr S end;
   154 
   155 fun inter_sort algebra (S1, S2) =
   156   sort_strings (fold (inter_class algebra) S1 S2);
   157 
   158 
   159 (* normal form *)
   160 
   161 fun norm_sort _ [] = []
   162   | norm_sort _ (S as [_]) = S
   163   | norm_sort algebra S =
   164       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   165       |> sort_distinct string_ord;
   166 
   167 
   168 (* certify *)
   169 
   170 fun certify_class algebra c =
   171   if can (Graph.get_node (classes_of algebra)) c then c
   172   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   173 
   174 fun certify_sort classes = norm_sort classes o map (certify_class classes);
   175 
   176 
   177 
   178 (** build algebras **)
   179 
   180 (* classes *)
   181 
   182 fun err_dup_classes cs =
   183   error ("Duplicate declaration of class(es): " ^ commas_quote cs);
   184 
   185 fun err_cyclic_classes pp css =
   186   error (cat_lines (map (fn cs =>
   187     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   188 
   189 fun add_class pp (c, cs) = map_classes (fn classes =>
   190   let
   191     val classes' = classes |> Graph.new_node (c, stamp ())
   192       handle Graph.DUP dup => err_dup_classes [dup];
   193     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   194       handle Graph.CYCLES css => err_cyclic_classes pp css;
   195   in classes'' end);
   196 
   197 
   198 (* arities *)
   199 
   200 local
   201 
   202 fun for_classes _ NONE = ""
   203   | for_classes pp (SOME (c1, c2)) =
   204       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   205 
   206 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   207   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   208     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   209     Pretty.string_of_arity pp (t, Ss', [c']));
   210 
   211 fun coregular pp algebra t (c, (c0, Ss)) ars =
   212   let
   213     fun conflict (c', (_, Ss')) =
   214       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   215         SOME ((c, c'), (c', Ss'))
   216       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   217         SOME ((c', c), (c', Ss'))
   218       else NONE;
   219   in
   220     (case get_first conflict ars of
   221       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   222     | NONE => (c, (c0, Ss)) :: ars)
   223   end;
   224 
   225 fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);
   226 
(*Insert the arity t::(Ss)c, derived via class c0, into the arity list ars
  of type constructor t.  If ars already has an entry for c with argument
  sorts Ss', only one of the two is kept:
    - sorts_le (Ss, Ss'): the existing entry already demands no more of the
      arguments, so ars is returned unchanged;
    - sorts_le (Ss', Ss): the new entry replaces the old one, subject to the
      coregularity check against the remaining entries;
    - otherwise the two arities are incomparable and err_conflict is raised.
  An entry for a class c not yet in ars is added after the coregularity
  check.*)
fun insert pp algebra t (c, (c0, Ss)) ars =
  (case AList.lookup (op =) ars c of
    NONE => coregular pp algebra t (c, (c0, Ss)) ars
  | SOME (_, Ss') =>
      if sorts_le algebra (Ss, Ss') then ars
      else if sorts_le algebra (Ss', Ss) then
        (*drop the superseded entry for c before re-checking coregularity*)
        coregular pp algebra t (c, (c0, Ss))
          (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
      else err_conflict pp t NONE (c, Ss) (c, Ss'));
   236 
   237 fun insert_ars pp algebra (t, ars) arities =
   238   let val ars' =
   239     Symtab.lookup_list arities t
   240     |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
   241   in Symtab.update (t, ars') arities end;
   242 
   243 in
   244 
   245 fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);
   246 
   247 fun add_arities_table pp algebra =
   248   Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));
   249 
   250 end;
   251 
   252 
   253 (* classrel *)
   254 
   255 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   256   Symtab.empty
   257   |> add_arities_table pp algebra arities);
   258 
   259 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   260   classes |> Graph.add_edge_trans_acyclic rel
   261     handle Graph.CYCLES css => err_cyclic_classes pp css);
   262 
   263 
   264 (* empty and merge *)
   265 
   266 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   267 
   268 fun merge_algebra pp
   269    (Algebra {classes = classes1, arities = arities1},
   270     Algebra {classes = classes2, arities = arities2}) =
   271   let
   272     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   273       handle Graph.DUPS cs => err_dup_classes cs
   274           | Graph.CYCLES css => err_cyclic_classes pp css;
   275     val algebra0 = make_algebra (classes', Symtab.empty);
   276     val arities' = Symtab.empty
   277       |> add_arities_table pp algebra0 arities1
   278       |> add_arities_table pp algebra0 arities2;
   279   in make_algebra (classes', arities') end;
   280 
   281 
   282 
   283 (** sorts of types **)
   284 
   285 (* errors *)
   286 
   287 datatype class_error = NoClassrel of class * class | NoArity of string * class;
   288 
   289 fun class_error pp (NoClassrel (c1, c2)) =
   290       error ("No class relation " ^ Pretty.string_of_classrel pp [c1, c2])
   291   | class_error pp (NoArity (a, c)) =
   292       error ("No type arity " ^ Pretty.string_of_arity pp (a, [], [c]));
   293 
   294 exception CLASS_ERROR of class_error;
   295 
   296 
   297 (* mg_domain *)
   298 
   299 fun mg_domain algebra a S =
   300   let
   301     val arities = arities_of algebra;
   302     fun dom c =
   303       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   304         NONE => raise CLASS_ERROR (NoArity (a, c))
   305       | SOME (_, Ss) => Ss);
   306     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   307   in
   308     (case S of
   309       [] => raise Fail "Unknown domain of empty intersection"
   310     | c :: cs => fold dom_inter cs (dom c))
   311   end;
   312 
   313 
   314 (* of_sort *)
   315 
   316 fun of_sort algebra =
   317   let
   318     fun ofS (_, []) = true
   319       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   320       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   321       | ofS (Type (a, Ts), S) =
   322           let val Ss = mg_domain algebra a S in
   323             ListPair.all ofS (Ts, Ss)
   324           end handle CLASS_ERROR _ => false;
   325   in ofS end;
   326 
   327 
   328 (* of_sort_derivation *)
   329 
   330 fun of_sort_derivation pp algebra {classrel, constructor, variable} =
   331   let
   332     val {classes, arities} = rep_algebra algebra;
   333     fun weaken_path (x, c1 :: c2 :: cs) = weaken_path (classrel (x, c1) c2, c2 :: cs)
   334       | weaken_path (x, _) = x;
   335     fun weaken (x, c1) c2 =
   336       (case Graph.irreducible_paths classes (c1, c2) of
   337         [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   338       | cs :: _ => weaken_path (x, cs));
   339 
   340     fun weakens S1 S2 = S2 |> map (fn c2 =>
   341       (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   342         SOME d1 => weaken d1 c2
   343       | NONE => error ("Cannot derive subsort relation " ^
   344           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   345 
   346     fun derive _ [] = []
   347       | derive (Type (a, Ts)) S =
   348           let
   349             val Ss = mg_domain algebra a S;
   350             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   351           in
   352             S |> map (fn c =>
   353               let
   354                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   355                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   356               in weaken (constructor a dom' c0, c0) c end)
   357           end
   358       | derive T S = weakens (variable T) S;
   359   in uncurry derive end;
   360 
   361 
   362 (* witness_sorts *)
   363 
(*Search witness types for the given sorts.  Witnesses are built from the
  type constructors named in types (via their most general domains) and
  from the hypothesis sorts hyps, which are represented by the free type
  variable 'hyp.  The empty sort is witnessed by propT.  The result keeps
  only the successful (witness type, sort) pairs; sorts without a witness
  are omitted.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*an already solved witness (T, S1) serves for S2 if S1 <= S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*a hypothesis sort S1 <= S2 gives the witness 'hyp :: S1 for S2*)
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*state threaded through the search: solved = witnesses found so far,
      failed = sorts recorded as having no witness;
      path = sorts currently being solved further up the recursion*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*S <= F for a failed sort F: a witness for S would also witness F*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get_solved S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in order; if none succeeds, S is recorded
      as failed.  NOTE(review): constructors may have been skipped only
      because of the current path, yet the failure of S is cached for the
      whole search -- confirm that this incompleteness is intended*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    (*all argument sorts witnessed: build t applied to the witnesses*)
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   401 
   402 end;