src/Pure/sorts.ML
author haftmann
Wed Apr 02 15:58:40 2008 +0200 (2008-04-02)
changeset 26517 ef036a63f6e9
parent 26326 a68045977f60
child 26639 75ea79a50326
permissions -rw-r--r--
canonical meet_sort operation
(*  Title:      Pure/sorts.ML
    ID:         $Id$
    Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen

The order-sorted algebra of type classes.

Classes denote (possibly empty) collections of types that are
partially ordered by class inclusion. They are represented
symbolically by strings.

Sorts are intersections of finitely many classes. They are represented
by lists of classes.  Normal forms of sorts are sorted lists of
minimal classes (with respect to the current class inclusion).
*)
    15 
    16 signature SORTS =
    17 sig
    18   val union: sort list -> sort list -> sort list
    19   val subtract: sort list -> sort list -> sort list
    20   val remove_sort: sort -> sort list -> sort list
    21   val insert_sort: sort -> sort list -> sort list
    22   val insert_typ: typ -> sort list -> sort list
    23   val insert_typs: typ list -> sort list -> sort list
    24   val insert_term: term -> sort list -> sort list
    25   val insert_terms: term list -> sort list -> sort list
    26   type algebra
    27   val rep_algebra: algebra ->
    28    {classes: serial Graph.T,
    29     arities: (class * (class * sort list)) list Symtab.table}
    30   val all_classes: algebra -> class list
    31   val minimal_classes: algebra -> class list
    32   val super_classes: algebra -> class -> class list
    33   val class_less: algebra -> class * class -> bool
    34   val class_le: algebra -> class * class -> bool
    35   val sort_eq: algebra -> sort * sort -> bool
    36   val sort_le: algebra -> sort * sort -> bool
    37   val sorts_le: algebra -> sort list * sort list -> bool
    38   val inter_sort: algebra -> sort * sort -> sort
    39   val minimize_sort: algebra -> sort -> sort
    40   val complete_sort: algebra -> sort -> sort
    41   val certify_class: algebra -> class -> class    (*exception TYPE*)
    42   val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
    43   val add_class: Pretty.pp -> class * class list -> algebra -> algebra
    44   val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
    45   val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
    46   val empty_algebra: algebra
    47   val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
    48   val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    49     -> algebra -> (sort -> sort) * algebra
    50   type class_error
    51   val msg_class_error: Pretty.pp -> class_error -> string
    52   val class_error: Pretty.pp -> class_error -> 'a
    53   exception CLASS_ERROR of class_error
    54   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    55   val of_sort: algebra -> typ * sort -> bool
    56   val meet_sort: algebra -> typ * sort -> sort Vartab.table -> sort Vartab.table
    57   val of_sort_derivation: Pretty.pp -> algebra ->
    58     {class_relation: 'a * class -> class -> 'a,
    59      type_constructor: string -> ('a * class) list list -> class -> 'a,
    60      type_variable: typ -> ('a * class) list} ->
    61     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    62   val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
    63 end;
    64 
    65 structure Sorts: SORTS =
    66 struct
    67 
    68 
    69 (** ordered lists of sorts **)
    70 
    71 val op union = OrdList.union Term.sort_ord;
    72 val subtract = OrdList.subtract Term.sort_ord;
    73 
    74 val remove_sort = OrdList.remove Term.sort_ord;
    75 val insert_sort = OrdList.insert Term.sort_ord;
    76 
    77 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    78   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    79   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    80 and insert_typs [] Ss = Ss
    81   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    82 
    83 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    84   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    85   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    86   | insert_term (Bound _) Ss = Ss
    87   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    88   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    89 
    90 fun insert_terms [] Ss = Ss
    91   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    92 
    93 
    94 
    95 (** order-sorted algebra **)
    96 
    97 (*
    98   classes: graph representing class declarations together with proper
    99     subclass relation, which needs to be transitive and acyclic.
   100 
   101   arities: table of association lists of all type arities; (t, ars)
   102     means that type constructor t has the arities ars; an element
   103     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
   104     via c0 <= c.  "Coregularity" of the arities structure requires
   105     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 such that
   106     c1 <= c2 holds Ss1 <= Ss2.
   107 *)
   108 
   109 datatype algebra = Algebra of
   110  {classes: serial Graph.T,
   111   arities: (class * (class * sort list)) list Symtab.table};
   112 
   113 fun rep_algebra (Algebra args) = args;
   114 
   115 val classes_of = #classes o rep_algebra;
   116 val arities_of = #arities o rep_algebra;
   117 
   118 fun make_algebra (classes, arities) =
   119   Algebra {classes = classes, arities = arities};
   120 
   121 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   122 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   123 
   124 
   125 (* classes *)
   126 
   127 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   128 
   129 val minimal_classes = Graph.minimals o classes_of;
   130 val super_classes = Graph.imm_succs o classes_of;
   131 
   132 
   133 (* class relations *)
   134 
   135 val class_less = Graph.is_edge o classes_of;
   136 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   137 
   138 
   139 (* sort relations *)
   140 
   141 fun sort_le algebra (S1, S2) =
   142   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   143 
   144 fun sorts_le algebra (Ss1, Ss2) =
   145   ListPair.all (sort_le algebra) (Ss1, Ss2);
   146 
   147 fun sort_eq algebra (S1, S2) =
   148   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   149 
   150 
   151 (* intersection *)
   152 
   153 fun inter_class algebra c S =
   154   let
   155     fun intr [] = [c]
   156       | intr (S' as c' :: c's) =
   157           if class_le algebra (c', c) then S'
   158           else if class_le algebra (c, c') then intr c's
   159           else c' :: intr c's
   160   in intr S end;
   161 
   162 fun inter_sort algebra (S1, S2) =
   163   sort_strings (fold (inter_class algebra) S1 S2);
   164 
   165 
   166 (* normal forms *)
   167 
   168 fun minimize_sort _ [] = []
   169   | minimize_sort _ (S as [_]) = S
   170   | minimize_sort algebra S =
   171       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   172       |> sort_distinct string_ord;
   173 
   174 fun complete_sort algebra =
   175   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   176 
   177 
   178 (* certify *)
   179 
   180 fun certify_class algebra c =
   181   if can (Graph.get_node (classes_of algebra)) c then c
   182   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   183 
   184 fun certify_sort classes = minimize_sort classes o map (certify_class classes);
   185 
   186 
   187 
   188 (** build algebras **)
   189 
   190 (* classes *)
   191 
   192 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   193 
   194 fun err_cyclic_classes pp css =
   195   error (cat_lines (map (fn cs =>
   196     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   197 
   198 fun add_class pp (c, cs) = map_classes (fn classes =>
   199   let
   200     val classes' = classes |> Graph.new_node (c, serial ())
   201       handle Graph.DUP dup => err_dup_class dup;
   202     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   203       handle Graph.CYCLES css => err_cyclic_classes pp css;
   204   in classes'' end);
   205 
   206 
   207 (* arities *)
   208 
   209 local
   210 
   211 fun for_classes _ NONE = ""
   212   | for_classes pp (SOME (c1, c2)) =
   213       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   214 
   215 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   216   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   217     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   218     Pretty.string_of_arity pp (t, Ss', [c']));
   219 
   220 fun coregular pp algebra t (c, (c0, Ss)) ars =
   221   let
   222     fun conflict (c', (_, Ss')) =
   223       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   224         SOME ((c, c'), (c', Ss'))
   225       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   226         SOME ((c', c), (c', Ss'))
   227       else NONE;
   228   in
   229     (case get_first conflict ars of
   230       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   231     | NONE => (c, (c0, Ss)) :: ars)
   232   end;
   233 
   234 fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);
   235 
   236 fun insert pp algebra t (c, (c0, Ss)) ars =
   237   (case AList.lookup (op =) ars c of
   238     NONE => coregular pp algebra t (c, (c0, Ss)) ars
   239   | SOME (_, Ss') =>
   240       if sorts_le algebra (Ss, Ss') then ars
   241       else if sorts_le algebra (Ss', Ss) then
   242         coregular pp algebra t (c, (c0, Ss))
   243           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   244       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   245 
   246 fun insert_ars pp algebra (t, ars) arities =
   247   let val ars' =
   248     Symtab.lookup_list arities t
   249     |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
   250   in Symtab.update (t, ars') arities end;
   251 
   252 in
   253 
   254 fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);
   255 
   256 fun add_arities_table pp algebra =
   257   Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));
   258 
   259 end;
   260 
   261 
   262 (* classrel *)
   263 
   264 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   265   Symtab.empty
   266   |> add_arities_table pp algebra arities);
   267 
   268 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   269   classes |> Graph.add_edge_trans_acyclic rel
   270     handle Graph.CYCLES css => err_cyclic_classes pp css);
   271 
   272 
   273 (* empty and merge *)
   274 
   275 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   276 
   277 fun merge_algebra pp
   278    (Algebra {classes = classes1, arities = arities1},
   279     Algebra {classes = classes2, arities = arities2}) =
   280   let
   281     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   282       handle Graph.DUP c => err_dup_class c
   283           | Graph.CYCLES css => err_cyclic_classes pp css;
   284     val algebra0 = make_algebra (classes', Symtab.empty);
   285     val arities' = Symtab.empty
   286       |> add_arities_table pp algebra0 arities1
   287       |> add_arities_table pp algebra0 arities2;
   288   in make_algebra (classes', arities') end;
   289 
   290 
   291 (* subalgebra *)
   292 
   293 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   294   let
   295     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   296     fun restrict_arity tyco (c, (_, Ss)) =
   297       if P c then
   298         SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) (sargs (c, tyco))
   299           |> map restrict_sort))
   300       else NONE;
   301     val classes' = classes |> Graph.subgraph P;
   302     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   303   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   304 
   305 
   306 
   307 (** sorts of types **)
   308 
   309 (* errors *)
   310 
   311 datatype class_error = NoClassrel of class * class
   312   | NoArity of string * class
   313   | NoSubsort of sort * sort
   314 
   315 fun msg_class_error pp (NoClassrel (c1, c2)) =
   316       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   317   | msg_class_error pp (NoArity (a, c)) =
   318       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   319   | msg_class_error pp (NoSubsort (s1, s2)) =
   320       Pretty.string_of_sort pp s1 ^ " is not subsort of " ^ Pretty.string_of_sort pp s2;
   321 
   322 fun class_error pp = error o msg_class_error pp;
   323 
   324 exception CLASS_ERROR of class_error;
   325 
   326 
   327 (* mg_domain *)
   328 
   329 fun mg_domain algebra a S =
   330   let
   331     val arities = arities_of algebra;
   332     fun dom c =
   333       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   334         NONE => raise CLASS_ERROR (NoArity (a, c))
   335       | SOME (_, Ss) => Ss);
   336     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   337   in
   338     (case S of
   339       [] => raise Fail "Unknown domain of empty intersection"
   340     | c :: cs => fold dom_inter cs (dom c))
   341   end;
   342 
   343 
   344 (* of_sort *)
   345 
   346 fun of_sort algebra =
   347   let
   348     fun ofS (_, []) = true
   349       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   350       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   351       | ofS (Type (a, Ts), S) =
   352           let val Ss = mg_domain algebra a S in
   353             ListPair.all ofS (Ts, Ss)
   354           end handle CLASS_ERROR _ => false;
   355   in ofS end;
   356 
   357 
   358 (* meet_sort *)
   359 
   360 fun meet_sort algebra =
   361   let
   362     fun meet _ [] = I
   363       | meet (TFree (_, S)) S' = if sort_le algebra (S, S') then I
   364           else raise CLASS_ERROR (NoSubsort (S, S'))
   365       | meet (TVar (v, S)) S' = if sort_le algebra (S, S') then I
   366           else Vartab.map_default (v, S) (curry (inter_sort algebra) S')
   367       | meet (Type (a, Ts)) S =
   368           fold2 meet Ts (mg_domain algebra a S)
   369   in uncurry meet end;
   370 
   371 
   372 (* of_sort_derivation *)
   373 
   374 fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
   375   let
   376     val {classes, arities} = rep_algebra algebra;
   377     fun weaken_path (x, c1 :: c2 :: cs) =
   378           weaken_path (class_relation (x, c1) c2, c2 :: cs)
   379       | weaken_path (x, _) = x;
   380     fun weaken (x, c1) c2 =
   381       (case Graph.irreducible_paths classes (c1, c2) of
   382         [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   383       | cs :: _ => weaken_path (x, cs));
   384 
   385     fun weakens S1 S2 = S2 |> map (fn c2 =>
   386       (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   387         SOME d1 => weaken d1 c2
   388       | NONE => error ("Cannot derive subsort relation " ^
   389           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   390 
   391     fun derive _ [] = []
   392       | derive (Type (a, Ts)) S =
   393           let
   394             val Ss = mg_domain algebra a S;
   395             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   396           in
   397             S |> map (fn c =>
   398               let
   399                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   400                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
   401               in weaken (type_constructor a dom' c0, c0) c end)
   402           end
   403       | derive T S = weakens (type_variable T) S;
   404   in uncurry derive end;
   405 
   406 
   407 (* witness_sorts *)
   408 
   409 fun witness_sorts algebra types hyps sorts =
   410   let
   411     fun le S1 S2 = sort_le algebra (S1, S2);
   412     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   413     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
   414     fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;
   415 
   416     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   417       | witn_sort path S (solved, failed) =
   418           if exists (le S) failed then (NONE, (solved, failed))
   419           else
   420             (case get_first (get_solved S) solved of
   421               SOME w => (SOME w, (solved, failed))
   422             | NONE =>
   423                 (case get_first (get_hyp S) hyps of
   424                   SOME w => (SOME w, (w :: solved, failed))
   425                 | NONE => witn_types path types S (solved, failed)))
   426 
   427     and witn_sorts path x = fold_map (witn_sort path) x
   428 
   429     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   430       | witn_types path (t :: ts) S solved_failed =
   431           (case mg_dom t S of
   432             SOME SS =>
   433               (*do not descend into stronger args (achieving termination)*)
   434               if exists (fn D => le D S orelse exists (le D) path) SS then
   435                 witn_types path ts S solved_failed
   436               else
   437                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
   438                   if forall is_some ws then
   439                     let val w = (Type (t, map (#1 o the) ws), S)
   440                     in (SOME w, (w :: solved', failed')) end
   441                   else witn_types path ts S (solved', failed')
   442                 end
   443           | NONE => witn_types path ts S solved_failed);
   444 
   445   in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   446 
   447 end;