src/Pure/sorts.ML
author wenzelm
Mon May 01 17:05:12 2006 +0200 (2006-05-01)
changeset 19524 f795c1164708
parent 19514 1f0218dab849
child 19529 690861f93d2b
permissions -rw-r--r--
arities: maintain original codomain;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 *)
     7 
(*Interface of the order-sorted algebra of type classes: ordered lists of
  sorts, the class/arity structures, inclusion tests, sort normalization,
  sort checking of types, witness construction, and algebra building.*)
signature SORTS =
sig
  (*ordered lists of sorts*)
  val eq_set: sort list * sort list -> bool
  val union: sort list -> sort list -> sort list
  val subtract: sort list -> sort list -> sort list
  val remove_sort: sort -> sort list -> sort list
  val insert_sort: sort -> sort list -> sort list
  (*accumulate the sorts occurring in types and terms*)
  val insert_typ: typ -> sort list -> sort list
  val insert_typs: typ list -> sort list -> sort list
  val insert_term: term -> sort list -> sort list
  val insert_terms: term list -> sort list -> sort list
  (*order-sorted algebra: class graph and arity table*)
  type classes
  type arities
  (*equality and inclusion of classes and sorts*)
  val class_eq: classes -> class * class -> bool
  val class_less: classes -> class * class -> bool
  val class_le: classes -> class * class -> bool
  val sort_eq: classes -> sort * sort -> bool
  val sort_le: classes -> sort * sort -> bool
  val sorts_le: classes -> sort list * sort list -> bool
  (*intersection and normal forms of sorts*)
  val inter_sort: classes -> sort * sort -> sort
  val norm_sort: classes -> sort -> sort
  (*sorts of types*)
  val of_sort: classes * arities -> typ * sort -> bool
  exception DOMAIN of string * class
  val mg_domain: classes * arities -> string -> sort -> sort list  (*exception DOMAIN*)
  val witness_sorts: classes * arities -> string list ->
    sort list -> sort list -> (typ * sort) list
  (*build sort algebras; errors are reported with the given pretty printer*)
  val add_arities: Pretty.pp -> classes -> string * (class * sort list) list -> arities -> arities
  val rebuild_arities: Pretty.pp -> classes -> arities -> arities
  val merge_arities: Pretty.pp -> classes -> arities * arities -> arities
  val add_class: Pretty.pp -> class * class list -> classes -> classes
  val add_classrel: Pretty.pp -> class * class -> classes -> classes
  val merge_classes: Pretty.pp -> classes * classes -> classes
end;
    41 
    42 structure Sorts: SORTS =
    43 struct
    44 
    45 
    46 (** type classes and sorts **)
    47 
    48 (*
    49   Classes denote (possibly empty) collections of types that are
    50   partially ordered by class inclusion. They are represented
    51   symbolically by strings.
    52 
    53   Sorts are intersections of finitely many classes. They are
    54   represented by lists of classes.  Normal forms of sorts are sorted
    55   lists of minimal classes (wrt. current class inclusion).
    56 *)
    57 
    58 
    59 (* ordered lists of sorts *)
    60 
    61 val eq_set = OrdList.eq_set Term.sort_ord;
    62 val op union = OrdList.union Term.sort_ord;
    63 val subtract = OrdList.subtract Term.sort_ord;
    64 
    65 val remove_sort = OrdList.remove Term.sort_ord;
    66 val insert_sort = OrdList.insert Term.sort_ord;
    67 
    68 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    69   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    70   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    71 and insert_typs [] Ss = Ss
    72   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    73 
    74 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    75   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    76   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    77   | insert_term (Bound _) Ss = Ss
    78   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    79   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    80 
    81 fun insert_terms [] Ss = Ss
    82   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    83 
    84 
    85 (* order-sorted algebra *)
    86 
(*
  classes: graph representing class declarations together with proper
    subclass relation, which needs to be transitive and acyclic.  Each
    node carries a stamp.

  arities: table of association lists of all type arities; (t, ars)
    means that type constructor t has the arities ars; an element (c,
    (c0, Ss)) of ars represents the arity t::(Ss)c being derived via
    c0 <= c, where c0 is the class of the originating declaration
    (entries are generated from Graph.all_succs of c0 -- presumably
    including c0 itself; confirm against Graph).  "Coregularity" of the
    arities structure requires that for any two declarations t::(Ss1)c1
    and t::(Ss2)c2 such that c1 <= c2, Ss1 <= Ss2 holds as well.
*)

type classes = stamp Graph.T;
type arities = (class * (class * sort list)) list Symtab.table;
   101 
   102 
   103 
   104 (** equality and inclusion **)
   105 
   106 (* classes *)
   107 
   108 fun class_eq (_: classes) (c1, c2:class) = c1 = c2;
   109 val class_less: classes -> class * class -> bool = Graph.is_edge;
   110 fun class_le classes (c1, c2) = c1 = c2 orelse class_less classes (c1, c2);
   111 
   112 
   113 (* sorts *)
   114 
   115 fun sort_le classes (S1, S2) =
   116   forall (fn c2 => exists (fn c1 => class_le classes (c1, c2)) S1) S2;
   117 
   118 fun sorts_le classes (Ss1, Ss2) =
   119   ListPair.all (sort_le classes) (Ss1, Ss2);
   120 
   121 fun sort_eq classes (S1, S2) =
   122   sort_le classes (S1, S2) andalso sort_le classes (S2, S1);
   123 
   124 
   125 (* normal forms of sorts *)
   126 
   127 fun minimal_class classes S c =
   128   not (exists (fn c' => class_less classes (c', c)) S);
   129 
   130 fun norm_sort _ [] = []
   131   | norm_sort _ (S as [_]) = S
   132   | norm_sort classes S = sort_distinct string_ord (filter (minimal_class classes S) S);
   133 
   134 
   135 
   136 (** intersection -- preserving minimality **)
   137 
   138 fun inter_class classes c S =
   139   let
   140     fun intr [] = [c]
   141       | intr (S' as c' :: c's) =
   142           if class_le classes (c', c) then S'
   143           else if class_le classes (c, c') then intr c's
   144           else c' :: intr c's
   145   in intr S end;
   146 
   147 fun inter_sort classes (S1, S2) =
   148   sort_strings (fold (inter_class classes) S1 S2);
   149 
   150 
   151 
   152 (** sorts of types **)
   153 
   154 (* mg_domain *)
   155 
   156 exception DOMAIN of string * class;
   157 
   158 fun mg_domain (classes, arities) a S =
   159   let
   160     fun dom c =
   161       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   162         NONE => raise DOMAIN (a, c)
   163       | SOME (_, Ss) => Ss);
   164     fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
   165   in
   166     (case S of
   167       [] => sys_error "mg_domain"  (*don't know number of args!*)
   168     | c :: cs => fold dom_inter cs (dom c))
   169   end;
   170 
   171 
   172 (* of_sort *)
   173 
   174 fun of_sort (classes, arities) =
   175   let
   176     fun ofS (_, []) = true
   177       | ofS (TFree (_, S), S') = sort_le classes (S, S')
   178       | ofS (TVar (_, S), S') = sort_le classes (S, S')
   179       | ofS (Type (a, Ts), S) =
   180           let val Ss = mg_domain (classes, arities) a S in
   181             ListPair.all ofS (Ts, Ss)
   182           end handle DOMAIN _ => false;
   183   in ofS end;
   184 
   185 
   186 
   187 (** witness_sorts **)
   188 
   189 local
   190 
   191 fun witness_aux (classes, arities) log_types hyps sorts =
   192   let
   193     val top_witn = (propT, []);
   194     fun le S1 S2 = sort_le classes (S1, S2);
   195     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   196     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
   197     fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle DOMAIN _ => NONE;
   198 
   199     fun witn_sort _ (solved_failed, []) = (solved_failed, SOME top_witn)
   200       | witn_sort path ((solved, failed), S) =
   201           if exists (le S) failed then ((solved, failed), NONE)
   202           else
   203             (case get_first (get_solved S) solved of
   204               SOME w => ((solved, failed), SOME w)
   205             | NONE =>
   206                 (case get_first (get_hyp S) hyps of
   207                   SOME w => ((w :: solved, failed), SOME w)
   208                 | NONE => witn_types path log_types ((solved, failed), S)))
   209 
   210     and witn_sorts path x = foldl_map (witn_sort path) x
   211 
   212     and witn_types _ [] ((solved, failed), S) = ((solved, S :: failed), NONE)
   213       | witn_types path (t :: ts) (solved_failed, S) =
   214           (case mg_dom t S of
   215             SOME SS =>
   216               (*do not descend into stronger args (achieving termination)*)
   217               if exists (fn D => le D S orelse exists (le D) path) SS then
   218                 witn_types path ts (solved_failed, S)
   219               else
   220                 let val ((solved', failed'), ws) = witn_sorts (S :: path) (solved_failed, SS) in
   221                   if forall is_some ws then
   222                     let val w = (Type (t, map (#1 o the) ws), S)
   223                     in ((w :: solved', failed'), SOME w) end
   224                   else witn_types path ts ((solved', failed'), S)
   225                 end
   226           | NONE => witn_types path ts (solved_failed, S));
   227 
   228   in witn_sorts [] (([], []), sorts) end;
   229 
   230 fun str_of_sort [c] = c
   231   | str_of_sort cs = enclose "{" "}" (commas cs);
   232 
   233 in
   234 
   235 fun witness_sorts (classes, arities) log_types hyps sorts =
   236   let
   237     fun double_check_result NONE = NONE
   238       | double_check_result (SOME (T, S)) =
   239           if of_sort (classes, arities) (T, S) then SOME (T, S)
   240           else sys_error ("Sorts.witness_sorts: bad witness for sort " ^ str_of_sort S);
   241   in map_filter double_check_result (#2 (witness_aux (classes, arities) log_types hyps sorts)) end;
   242 
   243 end;
   244 
   245 
   246 
   247 (** build sort algebras **)
   248 
   249 (* classes *)
   250 
   251 local
   252 
   253 fun err_dup_classes cs =
   254   error ("Duplicate declaration of class(es): " ^ commas_quote cs);
   255 
   256 fun err_cyclic_classes pp css =
   257   error (cat_lines (map (fn cs =>
   258     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   259 
   260 in
   261 
   262 fun add_class pp (c, cs) classes =
   263   let
   264     val classes' = classes |> Graph.new_node (c, stamp ())
   265       handle Graph.DUP dup => err_dup_classes [dup];
   266     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   267       handle Graph.CYCLES css => err_cyclic_classes pp css;
   268   in classes'' end;
   269 
   270 fun add_classrel pp rel classes =
   271   classes |> Graph.add_edge_trans_acyclic rel
   272     handle Graph.CYCLES css => err_cyclic_classes pp css;
   273 
   274 fun merge_classes pp args : classes =
   275   Graph.merge_trans_acyclic (op =) args
   276     handle Graph.DUPS cs => err_dup_classes cs
   277         | Graph.CYCLES css => err_cyclic_classes pp css;
   278 
   279 end;
   280 
   281 
   282 (* arities *)
   283 
   284 local
   285 
   286 fun for_classes _ NONE = ""
   287   | for_classes pp (SOME (c1, c2)) =
   288       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   289 
   290 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   291   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   292     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   293     Pretty.string_of_arity pp (t, Ss', [c']));
   294 
   295 fun coregular pp C t (c, (c0, Ss)) ars =
   296   let
   297     fun conflict (c', (_, Ss')) =
   298       if class_le C (c, c') andalso not (sorts_le C (Ss, Ss')) then
   299         SOME ((c, c'), (c', Ss'))
   300       else if class_le C (c', c) andalso not (sorts_le C (Ss', Ss)) then
   301         SOME ((c', c), (c', Ss'))
   302       else NONE;
   303   in
   304     (case get_first conflict ars of
   305       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   306     | NONE => (c, (c0, Ss)) :: ars)
   307   end;
   308 
   309 fun insert pp C t (c, (c0, Ss)) ars =
   310   (case AList.lookup (op =) ars c of
   311     NONE => coregular pp C t (c, (c0, Ss)) ars
   312   | SOME (_, Ss') =>
   313       if sorts_le C (Ss, Ss') then ars
   314       else if sorts_le C (Ss', Ss) then
   315         coregular pp C t (c, (c0, Ss))
   316           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   317       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   318 
   319 fun complete C (c0, Ss) = map (rpair (c0, Ss)) (Graph.all_succs C [c0]);
   320 
   321 in
   322 
   323 fun add_arities pp classes (t, ars) arities =
   324   let val ars' =
   325     Symtab.lookup_list arities t
   326     |> fold_rev (fold_rev (insert pp classes t)) (map (complete classes) ars)
   327   in Symtab.update (t, ars') arities end;
   328 
   329 fun add_arities_table pp classes = Symtab.fold (fn (t, ars) =>
   330   add_arities pp classes (t, map (apsnd (map (norm_sort classes)) o snd) ars));
   331 
   332 fun rebuild_arities pp classes arities =
   333   Symtab.empty
   334   |> add_arities_table pp classes arities;
   335 
   336 fun merge_arities pp classes (arities1, arities2) =
   337   Symtab.empty
   338   |> add_arities_table pp classes arities1
   339   |> add_arities_table pp classes arities2;
   340 
   341 end;
   342 
   343 end;