src/Pure/sorts.ML
author wenzelm
Tue May 02 00:20:38 2006 +0200 (2006-05-02)
changeset 19529 690861f93d2b
parent 19524 f795c1164708
child 19531 89970e06351f
permissions -rw-r--r--
added domain_error;
added of_sort_derivation;
tuned;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
(*Interface of the order-sorted algebra of type classes.*)
signature SORTS =
sig
  (*operations on sort lists ordered by Term.sort_ord (via OrdList)*)
  val eq_set: sort list * sort list -> bool
  val union: sort list -> sort list -> sort list
  val subtract: sort list -> sort list -> sort list
  val remove_sort: sort -> sort list -> sort list
  val insert_sort: sort -> sort list -> sort list
  (*add the sorts of all TFree/TVar occurrences in types/terms*)
  val insert_typ: typ -> sort list -> sort list
  val insert_typs: typ list -> sort list -> sort list
  val insert_term: term -> sort list -> sort list
  val insert_terms: term list -> sort list -> sort list
  (*class hierarchy and table of type constructor arities*)
  type classes
  type arities
  (*class_eq: equality of class names; class_less: proper subclass
    relation; class_le: class_less or equal*)
  val class_eq: classes -> class * class -> bool
  val class_less: classes -> class * class -> bool
  val class_le: classes -> class * class -> bool
  (*S1 <= S2 iff every class of S2 lies above some class of S1;
    sorts_le compares lists of sorts pointwise*)
  val sort_eq: classes -> sort * sort -> bool
  val sort_le: classes -> sort * sort -> bool
  val sorts_le: classes -> sort list * sort list -> bool
  val inter_sort: classes -> sort * sort -> sort
  (*sorted list of the minimal classes of a sort*)
  val norm_sort: classes -> sort -> sort
  (*building and merging arities/classes; the Pretty.pp argument is used
    to format error messages*)
  val add_arities: Pretty.pp -> classes -> string * (class * sort list) list -> arities -> arities
  val rebuild_arities: Pretty.pp -> classes -> arities -> arities
  val merge_arities: Pretty.pp -> classes -> arities * arities -> arities
  val add_class: Pretty.pp -> class * class list -> classes -> classes
  val add_classrel: Pretty.pp -> class * class -> classes -> classes
  val merge_classes: Pretty.pp -> classes * classes -> classes
  (*DOMAIN (a, c): no arity of type constructor a for class c*)
  exception DOMAIN of string * class
  val domain_error: Pretty.pp -> string * class -> 'a
  (*argument sorts required for a type constructor to belong to a sort:
    the pointwise intersection of the declared domains of its classes*)
  val mg_domain: classes * arities -> string -> sort -> sort list  (*exception DOMAIN*)
  val of_sort: classes * arities -> typ * sort -> bool
  (*derivation of T :: S assembled from callbacks: classrel lifts a
    derivation for one class to another, constructor combines argument
    derivations of a type constructor, variable yields derivations for
    a non-constructor type*)
  val of_sort_derivation: Pretty.pp -> classes * arities ->
    {classrel: 'a * class -> class -> 'a,
     constructor: string -> ('a * class) list list -> class -> 'a,
     variable: typ -> ('a * class) list} -> typ * sort -> 'a list
  (*witness_sorts algebra log_types hyps sorts: witness types for those
    of the given sorts that can be shown non-empty, trying the type
    constructors log_types and assuming the sorts hyps*)
  val witness_sorts: classes * arities -> string list ->
    sort list -> sort list -> (typ * sort) list
end;
    54 
    55 structure Sorts: SORTS =
    56 struct
    57 
    58 
    59 (** ordered lists of sorts **)
    60 
    61 val eq_set = OrdList.eq_set Term.sort_ord;
    62 val op union = OrdList.union Term.sort_ord;
    63 val subtract = OrdList.subtract Term.sort_ord;
    64 
    65 val remove_sort = OrdList.remove Term.sort_ord;
    66 val insert_sort = OrdList.insert Term.sort_ord;
    67 
    68 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    69   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    70   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    71 and insert_typs [] Ss = Ss
    72   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    73 
    74 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    75   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    76   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    77   | insert_term (Bound _) Ss = Ss
    78   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    79   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    80 
    81 fun insert_terms [] Ss = Ss
    82   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    83 
    84 
    85 
    86 (** order-sorted algebra **)
    87 
(*
  classes: graph representing class declarations together with the
    proper subclass relation, which needs to be transitive and acyclic.
    The node entries are stamps created at class declaration time.

  arities: table of association lists of all type arities.  (t, ars)
    means that type constructor t has the arities ars.  An element
    (c, (c0, Ss)) of ars represents the arity t::(Ss)c, derived from the
    declared arity t::(Ss)c0 (entries are generated for the classes
    returned by Graph.all_succs from c0).  "Coregularity" of the arities
    structure requires that for any two declarations t::(Ss1)c1 and
    t::(Ss2)c2 with c1 <= c2, also Ss1 <= Ss2 holds.
*)

type classes = stamp Graph.T;
type arities = (class * (class * sort list)) list Symtab.table;
   102 
   103 
   104 (* class relations *)
   105 
   106 fun class_eq (_: classes) (c1, c2:class) = c1 = c2;
   107 val class_less: classes -> class * class -> bool = Graph.is_edge;
   108 fun class_le classes (c1, c2) = c1 = c2 orelse class_less classes (c1, c2);
   109 
   110 
   111 (* sort relations *)
   112 
   113 fun sort_le classes (S1, S2) =
   114   forall (fn c2 => exists (fn c1 => class_le classes (c1, c2)) S1) S2;
   115 
   116 fun sorts_le classes (Ss1, Ss2) =
   117   ListPair.all (sort_le classes) (Ss1, Ss2);
   118 
   119 fun sort_eq classes (S1, S2) =
   120   sort_le classes (S1, S2) andalso sort_le classes (S2, S1);
   121 
   122 
   123 (* intersection *)
   124 
   125 fun inter_class classes c S =
   126   let
   127     fun intr [] = [c]
   128       | intr (S' as c' :: c's) =
   129           if class_le classes (c', c) then S'
   130           else if class_le classes (c, c') then intr c's
   131           else c' :: intr c's
   132   in intr S end;
   133 
   134 fun inter_sort classes (S1, S2) =
   135   sort_strings (fold (inter_class classes) S1 S2);
   136 
   137 
   138 (* normal forms *)
   139 
   140 fun norm_sort _ [] = []
   141   | norm_sort _ (S as [_]) = S
   142   | norm_sort classes S =
   143       filter (fn c => not (exists (fn c' => class_less classes (c', c)) S)) S
   144       |> sort_distinct string_ord;
   145 
   146 
   147 
   148 (** build algebras **)
   149 
   150 (* classes *)
   151 
   152 local
   153 
   154 fun err_dup_classes cs =
   155   error ("Duplicate declaration of class(es): " ^ commas_quote cs);
   156 
   157 fun err_cyclic_classes pp css =
   158   error (cat_lines (map (fn cs =>
   159     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   160 
   161 in
   162 
   163 fun add_class pp (c, cs) classes =
   164   let
   165     val classes' = classes |> Graph.new_node (c, stamp ())
   166       handle Graph.DUP dup => err_dup_classes [dup];
   167     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   168       handle Graph.CYCLES css => err_cyclic_classes pp css;
   169   in classes'' end;
   170 
   171 fun add_classrel pp rel classes =
   172   classes |> Graph.add_edge_trans_acyclic rel
   173     handle Graph.CYCLES css => err_cyclic_classes pp css;
   174 
   175 fun merge_classes pp args : classes =
   176   Graph.merge_trans_acyclic (op =) args
   177     handle Graph.DUPS cs => err_dup_classes cs
   178         | Graph.CYCLES css => err_cyclic_classes pp css;
   179 
   180 end;
   181 
   182 
   183 (* arities *)
   184 
   185 local
   186 
   187 fun for_classes _ NONE = ""
   188   | for_classes pp (SOME (c1, c2)) =
   189       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   190 
   191 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   192   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   193     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   194     Pretty.string_of_arity pp (t, Ss', [c']));
   195 
   196 fun coregular pp C t (c, (c0, Ss)) ars =
   197   let
   198     fun conflict (c', (_, Ss')) =
   199       if class_le C (c, c') andalso not (sorts_le C (Ss, Ss')) then
   200         SOME ((c, c'), (c', Ss'))
   201       else if class_le C (c', c) andalso not (sorts_le C (Ss', Ss)) then
   202         SOME ((c', c), (c', Ss'))
   203       else NONE;
   204   in
   205     (case get_first conflict ars of
   206       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   207     | NONE => (c, (c0, Ss)) :: ars)
   208   end;
   209 
   210 fun insert pp C t (c, (c0, Ss)) ars =
   211   (case AList.lookup (op =) ars c of
   212     NONE => coregular pp C t (c, (c0, Ss)) ars
   213   | SOME (_, Ss') =>
   214       if sorts_le C (Ss, Ss') then ars
   215       else if sorts_le C (Ss', Ss) then
   216         coregular pp C t (c, (c0, Ss))
   217           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   218       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   219 
   220 fun complete C (c0, Ss) = map (rpair (c0, Ss)) (Graph.all_succs C [c0]);
   221 
   222 in
   223 
   224 fun add_arities pp classes (t, ars) arities =
   225   let val ars' =
   226     Symtab.lookup_list arities t
   227     |> fold_rev (fold_rev (insert pp classes t)) (map (complete classes) ars)
   228   in Symtab.update (t, ars') arities end;
   229 
   230 fun add_arities_table pp classes =
   231   Symtab.fold (fn (t, ars) => add_arities pp classes (t, map snd ars));
   232 
   233 fun rebuild_arities pp classes arities =
   234   Symtab.empty
   235   |> add_arities_table pp classes arities;
   236 
   237 fun merge_arities pp classes (arities1, arities2) =
   238   Symtab.empty
   239   |> add_arities_table pp classes arities1
   240   |> add_arities_table pp classes arities2;
   241 
   242 end;
   243 
   244 
   245 
   246 (** sorts of types **)
   247 
   248 (* mg_domain *)
   249 
   250 exception DOMAIN of string * class;
   251 
   252 fun domain_error pp (a, c) =
   253   error ("No way to get " ^ Pretty.string_of_arity pp (a, [], [c]));
   254 
   255 fun mg_domain (classes, arities) a S =
   256   let
   257     fun dom c =
   258       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   259         NONE => raise DOMAIN (a, c)
   260       | SOME (_, Ss) => Ss);
   261     fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
   262   in
   263     (case S of
   264       [] => raise Fail "Unknown domain of empty intersection"
   265     | c :: cs => fold dom_inter cs (dom c))
   266   end;
   267 
   268 
   269 (* of_sort *)
   270 
   271 fun of_sort (classes, arities) =
   272   let
   273     fun ofS (_, []) = true
   274       | ofS (TFree (_, S), S') = sort_le classes (S, S')
   275       | ofS (TVar (_, S), S') = sort_le classes (S, S')
   276       | ofS (Type (a, Ts), S) =
   277           let val Ss = mg_domain (classes, arities) a S in
   278             ListPair.all ofS (Ts, Ss)
   279           end handle DOMAIN _ => false;
   280   in ofS end;
   281 
   282 
   283 (* of_sort_derivation *)
   284 
   285 fun of_sort_derivation pp (classes, arities) {classrel, constructor, variable} =
   286   let
   287     fun weaken (x, c1) c2 = if c1 = c2 then x else classrel (x, c1) c2;
   288     fun weakens S1 S2 = S2 |> map (fn c2 =>
   289       (case S1 |> find_first (fn (_, c1) => class_le classes (c1, c2)) of
   290         SOME d1 => weaken d1 c2
   291       | NONE => error ("Cannot derive subsort relation " ^
   292           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   293 
   294     fun derive _ [] = []
   295       | derive (Type (a, Ts)) S =
   296           let
   297             val Ss = mg_domain (classes, arities) a S
   298               handle DOMAIN d => domain_error pp d;
   299             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   300           in
   301             S |> map (fn c =>
   302               let
   303                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   304                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   305               in weaken (constructor a dom' c0, c0) c end)
   306           end
   307       | derive T S = weakens (variable T) S;
   308   in uncurry derive end;
   309 
   310 
   311 (* witness_sorts *)
   312 
local

(*Search for witness types of the given sorts.  log_types are type
  constructor names tried in order as candidates; hyps are sorts for which
  TFree ("'hyp", S1) serves as witness.  Two caches are threaded through
  the search: solved (witnesses found so far) and failed (sorts for which
  no witness was found).  Returns ((solved, failed), witnesses), with one
  optional witness per element of sorts.*)
fun witness_aux (classes, arities) log_types hyps sorts =
  let
    (*witness for the empty sort*)
    val top_witn = (propT, []);
    fun le S1 S2 = sort_le classes (S1, S2);
    (*reuse a witness T of S1 for S2 if S1 <= S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*use a hypothesis sort S1 <= S2*)
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle DOMAIN _ => NONE;

    (*path: sorts currently being searched further up the recursion*)
    fun witn_sort _ (solved_failed, []) = (solved_failed, SOME top_witn)
      | witn_sort path ((solved, failed), S) =
          (*S <= some already failed sort: give up immediately*)
          if exists (le S) failed then ((solved, failed), NONE)
          else
            (case get_first (get_solved S) solved of
              SOME w => ((solved, failed), SOME w)
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => ((w :: solved, failed), SOME w)
                | NONE => witn_types path log_types ((solved, failed), S)))

    and witn_sorts path x = foldl_map (witn_sort path) x

    (*try the candidate type constructors one by one; S is recorded as
      failed once all of them are exhausted*)
    and witn_types _ [] ((solved, failed), S) = ((solved, S :: failed), NONE)
      | witn_types path (t :: ts) (solved_failed, S) =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts (solved_failed, S)
              else
                let val ((solved', failed'), ws) = witn_sorts (S :: path) (solved_failed, SS) in
                  (*all argument sorts witnessed: build the type t applied to them*)
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in ((w :: solved', failed'), SOME w) end
                  else witn_types path ts ((solved', failed'), S)
                end
          | NONE => witn_types path ts (solved_failed, S));

  in witn_sorts [] (([], []), sorts) end;

(*render a sort for the error message below*)
fun str_of_sort [c] = c
  | str_of_sort cs = enclose "{" "}" (commas cs);

in

(*witnesses for those of the given sorts that could be found, each
  re-checked with of_sort; a failing check is a system error*)
fun witness_sorts (classes, arities) log_types hyps sorts =
  let
    fun double_check_result NONE = NONE
      | double_check_result (SOME (T, S)) =
          if of_sort (classes, arities) (T, S) then SOME (T, S)
          else sys_error ("Sorts.witness_sorts: bad witness for sort " ^ str_of_sort S);
  in map_filter double_check_result (#2 (witness_aux (classes, arities) log_types hyps sorts)) end;

end;
   368 
   369 end;