src/Pure/sorts.ML
author wenzelm
Thu Oct 04 20:29:42 2007 +0200 (2007-10-04)
changeset 24850 0cfd722ab579
parent 24732 08c2dd5378c7
child 26326 a68045977f60
permissions -rw-r--r--
Name.uu, Name.aT;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
    16 signature SORTS =
    17 sig
         (*ordered lists of sorts, and collecting the sorts occurring in types/terms*)
    18   val union: sort list -> sort list -> sort list
    19   val subtract: sort list -> sort list -> sort list
    20   val remove_sort: sort -> sort list -> sort list
    21   val insert_sort: sort -> sort list -> sort list
    22   val insert_typ: typ -> sort list -> sort list
    23   val insert_typs: typ list -> sort list -> sort list
    24   val insert_term: term -> sort list -> sort list
    25   val insert_terms: term list -> sort list -> sort list
         (*order-sorted algebra: abstract type with its concrete representation and queries*)
    26   type algebra
    27   val rep_algebra: algebra ->
    28    {classes: serial Graph.T,
    29     arities: (class * (class * sort list)) list Symtab.table}
    30   val all_classes: algebra -> class list
    31   val minimal_classes: algebra -> class list
    32   val super_classes: algebra -> class -> class list
    33   val class_less: algebra -> class * class -> bool
    34   val class_le: algebra -> class * class -> bool
    35   val sort_eq: algebra -> sort * sort -> bool
    36   val sort_le: algebra -> sort * sort -> bool
    37   val sorts_le: algebra -> sort list * sort list -> bool
    38   val inter_sort: algebra -> sort * sort -> sort
    39   val minimize_sort: algebra -> sort -> sort
    40   val complete_sort: algebra -> sort -> sort
    41   val certify_class: algebra -> class -> class    (*exception TYPE*)
    42   val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
         (*building algebras; violations are reported via error, formatted using the Pretty.pp argument*)
    43   val add_class: Pretty.pp -> class * class list -> algebra -> algebra
    44   val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
    45   val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
    46   val empty_algebra: algebra
    47   val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
    48   val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    49     -> algebra -> (sort -> sort) * algebra
         (*sorts of types: class errors, domains of type constructors, sort membership, witnesses*)
    50   type class_error
    51   val msg_class_error: Pretty.pp -> class_error -> string
    52   val class_error: Pretty.pp -> class_error -> 'a
    53   exception CLASS_ERROR of class_error
    54   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    55   val of_sort: algebra -> typ * sort -> bool
    56   val of_sort_derivation: Pretty.pp -> algebra ->
    57     {class_relation: 'a * class -> class -> 'a,
    58      type_constructor: string -> ('a * class) list list -> class -> 'a,
    59      type_variable: typ -> ('a * class) list} ->
    60     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    61   val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
    62 end;
    63 
    64 structure Sorts: SORTS =
    65 struct
    66 
    67 
    68 (** ordered lists of sorts **)
    69 
    70 val op union = OrdList.union Term.sort_ord;  (*OrdList operations instantiated with the order Term.sort_ord*)
    71 val subtract = OrdList.subtract Term.sort_ord;  (*NOTE(review): presumably both argument lists must already be ordered wrt. Term.sort_ord -- confirm in OrdList*)
    72 
    73 val remove_sort = OrdList.remove Term.sort_ord;  (*remove/insert a single sort, using the same order*)
    74 val insert_sort = OrdList.insert Term.sort_ord;
    75 
    (*insert_typ: add the sorts attached to the TFree/TVar leaves of a type to the
      sort list Ss (via insert_sort); insert_typs: the same for a list of types,
      processed from left to right*)
    fun insert_typ ty sorts =
      (case ty of
        Type (_, args) => insert_typs args sorts
      | TFree (_, S) => insert_sort S sorts
      | TVar (_, S) => insert_sort S sorts)
    and insert_typs tys sorts =
      (case tys of
        [] => sorts
      | ty :: rest => insert_typs rest (insert_typ ty sorts));
    81 
    (*add the sorts of all types occurring in a term to the given sort list;
      for an application the argument is traversed before the function part,
      for an abstraction the bound variable's type before the body*)
    fun insert_term tm sorts =
      (case tm of
        t $ u => insert_term t (insert_term u sorts)
      | Abs (_, T, body) => insert_term body (insert_typ T sorts)
      | Bound _ => sorts
      | Const (_, T) => insert_typ T sorts
      | Free (_, T) => insert_typ T sorts
      | Var (_, T) => insert_typ T sorts);
    88 
    (*insert_term over a list of terms, from left to right*)
    fun insert_terms tms sorts =
      (case tms of
        [] => sorts
      | tm :: rest => insert_terms rest (insert_term tm sorts));
    91 
    92 
    93 
    94 (** order-sorted algebra **)
    95 
    96 (*
    97   classes: graph representing class declarations together with proper
    98     subclass relation, which needs to be transitive and acyclic.
    99 
   100   arities: table of association lists of all type arities; (t, ars)
   101     means that type constructor t has the arities ars; an element
   102     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
   103     via c0 <= c.  "Coregularity" of the arities structure requires
   104     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 with
   105     c1 <= c2, the argument sorts satisfy Ss1 <= Ss2 (componentwise).
   106 *)
   107 
   (*concrete representation: class graph plus table of arities per type constructor*)
   datatype algebra = Algebra of
    {classes: serial Graph.T,
     arities: (class * (class * sort list)) list Symtab.table};

   (*expose the underlying record*)
   fun rep_algebra algebra = (case algebra of Algebra fields => fields);

   (*field selectors*)
   fun classes_of algebra = #classes (rep_algebra algebra);
   fun arities_of algebra = #arities (rep_algebra algebra);

   (*constructor from a (classes, arities) pair*)
   fun make_algebra (cls, ars) = Algebra {classes = cls, arities = ars};

   (*apply f to one component, keeping the other unchanged*)
   fun map_classes f algebra =
     let val {classes, arities} = rep_algebra algebra
     in make_algebra (f classes, arities) end;

   fun map_arities f algebra =
     let val {classes, arities} = rep_algebra algebra
     in make_algebra (classes, f arities) end;
   122 
   123 
   124 (* classes *)
   125 
   (*all predecessors (in the class graph) of its maximal nodes*)
   fun all_classes algebra =
     let val graph = classes_of algebra
     in Graph.all_preds graph (Graph.maximals graph) end;

   (*minimal nodes of the class graph*)
   fun minimal_classes algebra = Graph.minimals (classes_of algebra);

   (*immediate successors of a class in the class graph*)
   fun super_classes algebra = Graph.imm_succs (classes_of algebra);
   130 
   131 
   132 (* class relations *)
   133 
   (*proper subclass relation: (c1, c2) is an edge of the class graph*)
   fun class_less algebra = Graph.is_edge (classes_of algebra);

   (*reflexive subclass relation*)
   fun class_le algebra (sub, super) =
     if sub = super then true else class_less algebra (sub, super);
   136 
   137 
   138 (* sort relations *)
   139 
   (*S1 <= S2: the sorts are identical, or every class of S2 has some class of S1 below it*)
   fun sort_le algebra (S1, S2) =
     let
       fun below c2 c1 = class_le algebra (c1, c2);
       fun covered c2 = exists (below c2) S1;
     in S1 = S2 orelse forall covered S2 end;

   (*pointwise sort_le on two lists of sorts, as checked by ListPair.all*)
   fun sorts_le algebra sort_lists = ListPair.all (sort_le algebra) sort_lists;

   (*mutual sort_le*)
   fun sort_eq algebra (S1, S2) =
     if sort_le algebra (S1, S2) then sort_le algebra (S2, S1) else false;
   148 
   149 
   150 (* intersection *)
   151 
   (*intersect sort S with a single class c: S is returned unchanged from the first
     member that is already below c; members above c are dropped; if the end of the
     list is reached, c is appended*)
   fun inter_class algebra c S =
     let
       fun le pair = class_le algebra pair;
       fun add_to sort =
         (case sort of
           [] => [c]
         | d :: ds =>
             if le (d, c) then sort
             else if le (c, d) then add_to ds
             else d :: add_to ds);
     in add_to S end;
   160 
   (*intersection of two sorts: fold the classes of S1 into S2, then sort the class names*)
   fun inter_sort algebra (S1, S2) =
     S2
     |> fold (inter_class algebra) S1
     |> sort_strings;
   163 
   164 
   165 (* normal forms *)
   166 
   (*normal form of a sort: drop every class that has a proper subclass within the
     same sort, then sort the remaining names and remove duplicates; sorts with at
     most one class are returned as they are*)
   fun minimize_sort algebra S =
     (case S of
       [] => S
     | [_] => S
     | _ =>
         let
           fun has_proper_subclass c = exists (fn c' => class_less algebra (c', c)) S;
         in S |> filter_out has_proper_subclass |> sort_distinct string_ord end);
   172 
   (*all successors (in the class graph) of the minimized sort*)
   fun complete_sort algebra S =
     Graph.all_succs (classes_of algebra) (minimize_sort algebra S);
   175 
   176 
   177 (* certify *)
   178 
   (*return c if it is a node of the class graph, otherwise raise TYPE*)
   fun certify_class algebra c =
     let val declared = can (Graph.get_node (classes_of algebra)) c in
       if declared then c
       else raise TYPE ("Undeclared class: " ^ quote c, [], [])
     end;

   (*certify every class of a sort, then minimize it*)
   fun certify_sort algebra S =
     minimize_sort algebra (map (certify_class algebra) S);
   184 
   185 
   186 
   187 (** build algebras **)
   188 
   189 (* classes *)
   190 
   (*error for a class that is declared twice*)
   fun err_dup_class c =
     let val msg = "Duplicate declaration of class: " ^ quote c
     in error msg end;

   (*error listing all cycles found in the class relation, one per line*)
   fun err_cyclic_classes pp css =
     let fun cycle_msg cs = "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs
     in error (cat_lines (map cycle_msg css)) end;

   (*declare class c as a new graph node with a fresh serial number, then add an
     edge from c to each class in cs; duplicates and cycles are reported as errors*)
   fun add_class pp (c, cs) =
     let
       fun declare classes =
         Graph.new_node (c, serial ()) classes
           handle Graph.DUP dup => err_dup_class dup;
       fun relate classes =
         fold (fn super => Graph.add_edge_trans_acyclic (c, super)) cs classes
           handle Graph.CYCLES css => err_cyclic_classes pp css;
     in map_classes (relate o declare) end;
   204 
   205 
   206 (* arities *)
   207 
   208 local
   209 
       (*optional " for classes ..." suffix of the conflict message*)
   210 fun for_classes _ NONE = ""
   211   | for_classes pp (SOME (c1, c2)) =
   212       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   213 
       (*report two conflicting arities (Ss)c and (Ss')c' of type constructor t;
         cc optionally names the class relation that exposes the conflict*)
   214 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   215   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   216     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   217     Pretty.string_of_arity pp (t, Ss', [c']));
   218 
       (*check the new arity (c, (c0, Ss)) against every existing entry of ars and
         prepend it if no conflict is found*)
   219 fun coregular pp algebra t (c, (c0, Ss)) ars =
   220   let
         (*an entry conflicts if the classes are related by class_le in one direction
           but the argument sorts are not related by sorts_le in the same direction*)
   221     fun conflict (c', (_, Ss')) =
   222       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   223         SOME ((c, c'), (c', Ss'))
   224       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   225         SOME ((c', c), (c', Ss'))
   226       else NONE;
   227   in
         (*only the first conflict found is reported*)
   228     (case get_first conflict ars of
   229       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   230     | NONE => (c, (c0, Ss)) :: ars)
   231   end;
   232 
       (*expand an arity for c0 into entries for c0 and each of super_classes c0,
         each entry recording (c0, Ss) as its origin*)
   233 fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);
   234 
       (*insert one entry into the arity list ars of type constructor t*)
   235 fun insert pp algebra t (c, (c0, Ss)) ars =
   236   (case AList.lookup (op =) ars c of
   237     NONE => coregular pp algebra t (c, (c0, Ss)) ars
   238   | SOME (_, Ss') =>
           (*existing entry for c: keep it if Ss <= Ss'; replace it by the new one if
             Ss' <= Ss (after the coregularity check against the rest); otherwise conflict*)
   239       if sorts_le algebra (Ss, Ss') then ars
   240       else if sorts_le algebra (Ss', Ss) then
   241         coregular pp algebra t (c, (c0, Ss))
   242           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   243       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   244 
       (*insert the completed arities ars of type constructor t into the table;
         Symtab.lookup_list supplies the entries already present for t*)
   245 fun insert_ars pp algebra (t, ars) arities =
   246   let val ars' =
   247     Symtab.lookup_list arities t
   248     |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
   249   in Symtab.update (t, ars') arities end;
   250 
   251 in
   252 
       (*add arities for one type constructor; the algebra itself is passed to
         insert_ars for the class and sort comparisons*)
   253 fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);
   254 
       (*re-insert all entries of an arities table into another table; only the
         (c0, Ss) origin of each entry is kept (map snd), so insert_ars completes it again*)
   255 fun add_arities_table pp algebra =
   256   Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));
   257 
   258 end;
   259 
   260 
   261 (* classrel *)
   262 
   (*recompute the arities table from scratch: all entries of the current table are
     re-inserted into an empty one, using the given algebra for comparisons*)
   fun rebuild_arities pp algebra =
     let fun rebuild arities = add_arities_table pp algebra arities Symtab.empty
     in map_arities rebuild algebra end;

   (*add a class relation as an edge of the class graph (cycles are reported as
     errors), then rebuild the arities*)
   fun add_classrel pp rel algebra =
     let
       fun extend classes =
         Graph.add_edge_trans_acyclic rel classes
           handle Graph.CYCLES css => err_cyclic_classes pp css;
     in rebuild_arities pp (map_classes extend algebra) end;
   270 
   271 
   272 (* empty and merge *)
   273 
   (*algebra without classes and arities*)
   val empty_algebra = Algebra {classes = Graph.empty, arities = Symtab.empty};

   (*merge two algebras: the class graphs are merged first (duplicates and cycles are
     reported as errors); then the arities of the first and of the second algebra are
     inserted, in this order, into an empty table relative to the merged classes*)
   fun merge_algebra pp (algebra1, algebra2) =
     let
       val graphs = (classes_of algebra1, classes_of algebra2);
       val merged_classes =
         Graph.merge_trans_acyclic (op =) graphs
           handle Graph.DUP c => err_dup_class c
             | Graph.CYCLES css => err_cyclic_classes pp css;
       val base = make_algebra (merged_classes, Symtab.empty);
       fun import algebra = add_arities_table pp base (arities_of algebra);
       val merged_arities = Symtab.empty |> import algebra1 |> import algebra2;
     in make_algebra (merged_classes, merged_arities) end;
   288 
   289 
   290 (* subalgebra *)
   291 
   (*restrict an algebra to the classes satisfying P; returns the sort restriction
     function together with the restricted algebra.  For each kept arity of class c
     and type constructor tyco, the argument sorts are intersected pointwise with
     sargs (c, tyco) and then restricted*)
   fun subalgebra pp P sargs algebra =
     let
       val {classes, arities} = rep_algebra algebra;

       (*all successors of the sort's classes, filtered by P, then minimized*)
       fun restrict_sort S =
         minimize_sort algebra (filter P (Graph.all_succs classes S));

       fun restrict_arity tyco (c, (_, Ss)) =
         if not (P c) then NONE
         else
           let
             val Ss' =
               map2 (fn S_arg => fn S => inter_sort algebra (S_arg, S)) (sargs (c, tyco)) Ss;
           in SOME (c, (c, map restrict_sort Ss')) end;

       val classes' = Graph.subgraph P classes;
       val arities' = Symtab.map' (fn tyco => map_filter (restrict_arity tyco)) arities;
     in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   303 
   304 
   305 
   306 (** sorts of types **)
   307 
   308 (* errors *)
   309 
   (*missing class relation c1 < c2, or missing arity of a type constructor for a class*)
   datatype class_error =
     NoClassrel of class * class |
     NoArity of string * class;

   (*human-readable message for a class error*)
   fun msg_class_error pp err =
     (case err of
       NoArity (a, c) => "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
     | NoClassrel (c1, c2) => "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]);

   (*turn a class error into a regular error with the above message*)
   fun class_error pp err = error (msg_class_error pp err);

   exception CLASS_ERROR of class_error;
   320 
   321 
   322 (* mg_domain *)
   323 
   (*argument sorts required for type constructor a to be of sort S: the domains of
     a for the individual classes of S, intersected pointwise.  Raises CLASS_ERROR
     (NoArity) if a has no arity for some class of S, and Fail for the empty sort*)
   fun mg_domain algebra a S =
     let
       val ars = Symtab.lookup_list (arities_of algebra) a;
       fun dom c =
         (case AList.lookup (op =) ars c of
           SOME (_, Ss) => Ss
         | NONE => raise CLASS_ERROR (NoArity (a, c)));
       fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
     in
       if null S then raise Fail "Unknown domain of empty intersection"
       else fold dom_inter (tl S) (dom (hd S))
     end;
   337 
   338 
   339 (* of_sort *)
   340 
   (*test whether a type is of a given sort: everything is of the empty sort; type
     variables are checked via sort_le on their own sort; for a type constructor the
     arguments are checked against mg_domain, and a CLASS_ERROR counts as failure*)
   fun of_sort algebra =
     let
       fun check (T, S) =
         if null S then true
         else
           (case T of
             Type (a, Ts) =>
               (ListPair.all check (Ts, mg_domain algebra a S)
                 handle CLASS_ERROR _ => false)
           | TFree (_, S0) => sort_le algebra (S0, S)
           | TVar (_, S0) => sort_le algebra (S0, S));
     in check end;
   351 
   352 
   353 (* of_sort_derivation *)
   354 
   355 fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
         (*build one derivation of type 'a per class of the requested sort, using the
           three callbacks for class relation steps, type constructor arities and
           type variables; the result function takes the pair (typ, sort)*)
   356   let
   357     val {classes, arities} = rep_algebra algebra;
           (*apply class_relation along each pair of consecutive classes of a path*)
   358     fun weaken_path (x, c1 :: c2 :: cs) =
   359           weaken_path (class_relation (x, c1) c2, c2 :: cs)
   360       | weaken_path (x, _) = x;
           (*weaken derivation x of class c1 to class c2 along the first irreducible
             path of the class graph; CLASS_ERROR (NoClassrel) if there is none*)
   361     fun weaken (x, c1) c2 =
   362       (case Graph.irreducible_paths classes (c1, c2) of
   363         [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   364       | cs :: _ => weaken_path (x, cs));
   365 
           (*for each class c2 of S2 take the first derivation in S1 whose class is
             below c2 and weaken it to c2; this failure is a plain error, not CLASS_ERROR*)
   366     fun weakens S1 S2 = S2 |> map (fn c2 =>
   367       (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
   368         SOME d1 => weaken d1 c2
   369       | NONE => error ("Cannot derive subsort relation " ^
   370           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   371 
           (*the empty sort needs no derivations*)
   372     fun derive _ [] = []
   373       | derive (Type (a, Ts)) S =
   374           let
                 (*derive the arguments at the domain sorts of a for S, paired with their classes*)
   375             val Ss = mg_domain algebra a S;
   376             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   377           in
   378             S |> map (fn c =>
   379               let
                     (*the arity of a for c was derived via class c0 with argument sorts Ss';
                       the lookup cannot fail here since mg_domain succeeded for S*)
   380                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   381                 val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
                   (*construct at c0, then weaken the result from c0 to c*)
   382               in weaken (type_constructor a dom' c0, c0) c end)
   383           end
             (*any other type: weaken what type_variable provides*)
   384       | derive T S = weakens (type_variable T) S;
   385   in uncurry derive end;
   386 
   387 
   388 (* witness_sorts *)
   389 
   390 fun witness_sorts algebra types hyps sorts =
         (*search witness types for the given sorts, built from the type constructors
           in types and from hypothetical type variables of the sorts in hyps; the
           result contains a pair (witness type, sort) for each sort where one was found*)
   391   let
   392     fun le S1 S2 = sort_le algebra (S1, S2);
           (*reuse an already solved witness whose sort is below S2*)
   393     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
           (*a hypothesis sort below S2 yields the type variable 'hyp of that sort*)
   394     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
           (*mg_domain as an option, NONE on CLASS_ERROR*)
   395     fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;
   396 
           (*the state (solved, failed) is threaded through the search: solved holds the
             witnesses found so far, failed the sorts for which every constructor failed;
             the empty sort is witnessed by propT*)
   397     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   398       | witn_sort path S (solved, failed) =
               (*give up at once if S is below a sort that has already failed*)
   399           if exists (le S) failed then (NONE, (solved, failed))
   400           else
   401             (case get_first (get_solved S) solved of
   402               SOME w => (SOME w, (solved, failed))
   403             | NONE =>
   404                 (case get_first (get_hyp S) hyps of
   405                   SOME w => (SOME w, (w :: solved, failed))
   406                 | NONE => witn_types path types S (solved, failed)))
   407 
   408     and witn_sorts path x = fold_map (witn_sort path) x
   409 
           (*try the type constructors one by one; if all fail, record S as failed*)
   410     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   411       | witn_types path (t :: ts) S solved_failed =
   412           (case mg_dom t S of
   413             SOME SS =>
   414               (*do not descend into stronger args (achieving termination)*)
   415               if exists (fn D => le D S orelse exists (le D) path) SS then
   416                 witn_types path ts S solved_failed
   417               else
   418                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                       (*t is a witness only if every argument sort has a witness*)
   419                   if forall is_some ws then
   420                     let val w = (Type (t, map (#1 o the) ws), S)
   421                     in (SOME w, (w :: solved', failed')) end
   422                   else witn_types path ts S (solved', failed')
   423                 end
   424           | NONE => witn_types path ts S solved_failed);
   425 
   426   in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   427 
   428 end;