src/Pure/sorts.ML
author wenzelm
Fri May 05 21:59:46 2006 +0200 (2006-05-05)
changeset 19578 f93b7637a5e6
parent 19531 89970e06351f
child 19584 606d6a73e6d9
permissions -rw-r--r--
added class_error and exception CLASS_ERROR (supersedes DOMAIN);
clarified of_class_derivation;
tuned witness_sorts;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
    16 signature SORTS =
    17 sig
    18   val eq_set: sort list * sort list -> bool
    19   val union: sort list -> sort list -> sort list
    20   val subtract: sort list -> sort list -> sort list
    21   val remove_sort: sort -> sort list -> sort list
    22   val insert_sort: sort -> sort list -> sort list
    23   val insert_typ: typ -> sort list -> sort list
    24   val insert_typs: typ list -> sort list -> sort list
    25   val insert_term: term -> sort list -> sort list
    26   val insert_terms: term list -> sort list -> sort list
    27   type classes
    28   type arities
    29   val class_eq: classes -> class * class -> bool
    30   val class_less: classes -> class * class -> bool
    31   val class_le: classes -> class * class -> bool
    32   val sort_eq: classes -> sort * sort -> bool
    33   val sort_le: classes -> sort * sort -> bool
    34   val sorts_le: classes -> sort list * sort list -> bool
    35   val inter_sort: classes -> sort * sort -> sort
    36   val norm_sort: classes -> sort -> sort
    37   val add_arities: Pretty.pp -> classes -> string * (class * sort list) list -> arities -> arities
    38   val rebuild_arities: Pretty.pp -> classes -> arities -> arities
    39   val merge_arities: Pretty.pp -> classes -> arities * arities -> arities
    40   val add_class: Pretty.pp -> class * class list -> classes -> classes
    41   val add_classrel: Pretty.pp -> class * class -> classes -> classes
    42   val merge_classes: Pretty.pp -> classes * classes -> classes
    43   type class_error
    44   val class_error: Pretty.pp -> class_error -> 'a
    45   exception CLASS_ERROR of class_error
    46   val mg_domain: classes * arities -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    47   val of_sort: classes * arities -> typ * sort -> bool
    48   val of_sort_derivation: Pretty.pp -> classes * arities ->
    49     {classrel: 'a * class -> class -> 'a,
    50      constructor: string -> ('a * class) list list -> class -> 'a,
    51      variable: typ -> ('a * class) list} -> typ * sort -> 'a list   (*exception CLASS_ERROR*)
    52   val witness_sorts: classes * arities -> string list ->
    53     sort list -> sort list -> (typ * sort) list
    54 end;
    55 
    56 structure Sorts: SORTS =
    57 struct
    58 
    59 
    60 (** ordered lists of sorts **)
    61 
    62 val eq_set = OrdList.eq_set Term.sort_ord;
    63 val op union = OrdList.union Term.sort_ord;
    64 val subtract = OrdList.subtract Term.sort_ord;
    65 
    66 val remove_sort = OrdList.remove Term.sort_ord;
    67 val insert_sort = OrdList.insert Term.sort_ord;
    68 
    69 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    70   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    71   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    72 and insert_typs [] Ss = Ss
    73   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    74 
    75 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    76   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    77   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    78   | insert_term (Bound _) Ss = Ss
    79   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    80   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    81 
    82 fun insert_terms [] Ss = Ss
    83   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    84 
    85 
    86 
    87 (** order-sorted algebra **)
    88 
    89 (*
    90   classes: graph representing class declarations together with proper
    91     subclass relation, which needs to be transitive and acyclic.
    92 
    93   arities: table of association lists of all type arities; (t, ars)
    94     means that type constructor t has the arities ars; an element
    95     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
    96     via c0 <= c.  "Coregularity" of the arities structure requires
    97     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 such that
    98     c1 <= c2 holds Ss1 <= Ss2.
    99 *)
   100 
   101 type classes = stamp Graph.T;
   102 type arities = (class * (class * sort list)) list Symtab.table;
   103 
   104 
   105 (* class relations *)
   106 
   107 fun class_eq (_: classes) (c1, c2:class) = c1 = c2;
   108 val class_less: classes -> class * class -> bool = Graph.is_edge;
   109 fun class_le classes (c1, c2) = c1 = c2 orelse class_less classes (c1, c2);
   110 
   111 
   112 (* sort relations *)
   113 
   114 fun sort_le classes (S1, S2) =
   115   forall (fn c2 => exists (fn c1 => class_le classes (c1, c2)) S1) S2;
   116 
   117 fun sorts_le classes (Ss1, Ss2) =
   118   ListPair.all (sort_le classes) (Ss1, Ss2);
   119 
   120 fun sort_eq classes (S1, S2) =
   121   sort_le classes (S1, S2) andalso sort_le classes (S2, S1);
   122 
   123 
   124 (* intersection *)
   125 
   126 fun inter_class classes c S =
   127   let
   128     fun intr [] = [c]
   129       | intr (S' as c' :: c's) =
   130           if class_le classes (c', c) then S'
   131           else if class_le classes (c, c') then intr c's
   132           else c' :: intr c's
   133   in intr S end;
   134 
   135 fun inter_sort classes (S1, S2) =
   136   sort_strings (fold (inter_class classes) S1 S2);
   137 
   138 
   139 (* normal forms *)
   140 
   141 fun norm_sort _ [] = []
   142   | norm_sort _ (S as [_]) = S
   143   | norm_sort classes S =
   144       filter (fn c => not (exists (fn c' => class_less classes (c', c)) S)) S
   145       |> sort_distinct string_ord;
   146 
   147 
   148 
   149 (** build algebras **)
   150 
   151 (* classes *)
   152 
   153 local
   154 
   155 fun err_dup_classes cs =
   156   error ("Duplicate declaration of class(es): " ^ commas_quote cs);
   157 
   158 fun err_cyclic_classes pp css =
   159   error (cat_lines (map (fn cs =>
   160     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   161 
   162 in
   163 
   164 fun add_class pp (c, cs) classes =
   165   let
   166     val classes' = classes |> Graph.new_node (c, stamp ())
   167       handle Graph.DUP dup => err_dup_classes [dup];
   168     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   169       handle Graph.CYCLES css => err_cyclic_classes pp css;
   170   in classes'' end;
   171 
   172 fun add_classrel pp rel classes =
   173   classes |> Graph.add_edge_trans_acyclic rel
   174     handle Graph.CYCLES css => err_cyclic_classes pp css;
   175 
   176 fun merge_classes pp args : classes =
   177   Graph.merge_trans_acyclic (op =) args
   178     handle Graph.DUPS cs => err_dup_classes cs
   179         | Graph.CYCLES css => err_cyclic_classes pp css;
   180 
   181 end;
   182 
   183 
   184 (* arities *)
   185 
   186 local
   187 
   188 fun for_classes _ NONE = ""
   189   | for_classes pp (SOME (c1, c2)) =
   190       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   191 
   192 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   193   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   194     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   195     Pretty.string_of_arity pp (t, Ss', [c']));
   196 
   197 fun coregular pp C t (c, (c0, Ss)) ars =
   198   let
   199     fun conflict (c', (_, Ss')) =
   200       if class_le C (c, c') andalso not (sorts_le C (Ss, Ss')) then
   201         SOME ((c, c'), (c', Ss'))
   202       else if class_le C (c', c) andalso not (sorts_le C (Ss', Ss)) then
   203         SOME ((c', c), (c', Ss'))
   204       else NONE;
   205   in
   206     (case get_first conflict ars of
   207       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   208     | NONE => (c, (c0, Ss)) :: ars)
   209   end;
   210 
   211 fun insert pp C t (c, (c0, Ss)) ars =
   212   (case AList.lookup (op =) ars c of
   213     NONE => coregular pp C t (c, (c0, Ss)) ars
   214   | SOME (_, Ss') =>
   215       if sorts_le C (Ss, Ss') then ars
   216       else if sorts_le C (Ss', Ss) then
   217         coregular pp C t (c, (c0, Ss))
   218           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   219       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   220 
   221 fun complete C (c0, Ss) = map (rpair (c0, Ss)) (Graph.all_succs C [c0]);
   222 
   223 in
   224 
   225 fun add_arities pp classes (t, ars) arities =
   226   let val ars' =
   227     Symtab.lookup_list arities t
   228     |> fold_rev (fold_rev (insert pp classes t)) (map (complete classes) ars)
   229   in Symtab.update (t, ars') arities end;
   230 
   231 fun add_arities_table pp classes =
   232   Symtab.fold (fn (t, ars) => add_arities pp classes (t, map snd ars));
   233 
   234 fun rebuild_arities pp classes arities =
   235   Symtab.empty
   236   |> add_arities_table pp classes arities;
   237 
   238 fun merge_arities pp classes (arities1, arities2) =
   239   Symtab.empty
   240   |> add_arities_table pp classes arities1
   241   |> add_arities_table pp classes arities2;
   242 
   243 end;
   244 
   245 
   246 
   247 (** sorts of types **)
   248 
   249 (* errors *)
   250 
   251 datatype class_error = NoClassrel of class * class | NoArity of string * class;
   252 
   253 fun class_error pp (NoClassrel (c1, c2)) =
   254       error ("No class relation " ^ Pretty.string_of_classrel pp [c1, c2])
   255   | class_error pp (NoArity (a, c)) =
   256       error ("No type arity " ^ Pretty.string_of_arity pp (a, [], [c]));
   257 
   258 exception CLASS_ERROR of class_error;
   259 
   260 
   261 (* mg_domain *)
   262 
   263 fun mg_domain (classes, arities) a S =
   264   let
   265     fun dom c =
   266       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   267         NONE => raise CLASS_ERROR (NoArity (a, c))
   268       | SOME (_, Ss) => Ss);
   269     fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
   270   in
   271     (case S of
   272       [] => raise Fail "Unknown domain of empty intersection"
   273     | c :: cs => fold dom_inter cs (dom c))
   274   end;
   275 
   276 
   277 (* of_sort *)
   278 
   279 fun of_sort (classes, arities) =
   280   let
   281     fun ofS (_, []) = true
   282       | ofS (TFree (_, S), S') = sort_le classes (S, S')
   283       | ofS (TVar (_, S), S') = sort_le classes (S, S')
   284       | ofS (Type (a, Ts), S) =
   285           let val Ss = mg_domain (classes, arities) a S in
   286             ListPair.all ofS (Ts, Ss)
   287           end handle CLASS_ERROR _ => false;
   288   in ofS end;
   289 
   290 
   291 (* of_sort_derivation *)
   292 
   293 fun of_sort_derivation pp (classes, arities) {classrel, constructor, variable} =
   294   let
   295     fun weaken_path (x, c1 :: c2 :: cs) = weaken_path (classrel (x, c1) c2, c2 :: cs)
   296       | weaken_path (x, _) = x;
   297     fun weaken (x, c1) c2 =
   298       (case Graph.irreducible_paths classes (c1, c2) of
   299         [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   300       | cs :: _ => weaken_path (x, cs));
   301 
   302     fun weakens S1 S2 = S2 |> map (fn c2 =>
   303       (case S1 |> find_first (fn (_, c1) => class_le classes (c1, c2)) of
   304         SOME d1 => weaken d1 c2
   305       | NONE => error ("Cannot derive subsort relation " ^
   306           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   307 
   308     fun derive _ [] = []
   309       | derive (Type (a, Ts)) S =
   310           let
   311             val Ss = mg_domain (classes, arities) a S;
   312             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   313           in
   314             S |> map (fn c =>
   315               let
   316                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   317                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   318               in weaken (constructor a dom' c0, c0) c end)
   319           end
   320       | derive T S = weakens (variable T) S;
   321   in uncurry derive end;
   322 
   323 
   324 (* witness_sorts *)
   325 
   326 fun witness_sorts (classes, arities) log_types hyps sorts =
   327   let
   328     fun le S1 S2 = sort_le classes (S1, S2);
   329     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   330     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
   331     fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle CLASS_ERROR _ => NONE;
   332 
   333     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   334       | witn_sort path S (solved, failed) =
   335           if exists (le S) failed then (NONE, (solved, failed))
   336           else
   337             (case get_first (get_solved S) solved of
   338               SOME w => (SOME w, (solved, failed))
   339             | NONE =>
   340                 (case get_first (get_hyp S) hyps of
   341                   SOME w => (SOME w, (w :: solved, failed))
   342                 | NONE => witn_types path log_types S (solved, failed)))
   343 
   344     and witn_sorts path x = fold_map (witn_sort path) x
   345 
   346     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   347       | witn_types path (t :: ts) S solved_failed =
   348           (case mg_dom t S of
   349             SOME SS =>
   350               (*do not descend into stronger args (achieving termination)*)
   351               if exists (fn D => le D S orelse exists (le D) path) SS then
   352                 witn_types path ts S solved_failed
   353               else
   354                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
   355                   if forall is_some ws then
   356                     let val w = (Type (t, map (#1 o the) ws), S)
   357                     in (SOME w, (w :: solved', failed')) end
   358                   else witn_types path ts S (solved', failed')
   359                 end
   360           | NONE => witn_types path ts S solved_failed);
   361 
   362     fun double_check TS =
   363       if of_sort (classes, arities) TS then TS
   364       else sys_error "FIXME Bad sort witness";
   365 
   366   in map_filter (Option.map double_check) (#1 (witn_sorts [] sorts ([], []))) end;
   367 
   368 end;