src/Pure/sorts.ML
author wenzelm
Fri Jul 11 23:17:25 2008 +0200 (2008-07-11)
changeset 27555 dfda3192dec2
parent 27548 8178db4b25f3
child 28354 c5fe7372ae4e
permissions -rw-r--r--
Sorts.weaken: abstract argument;
tuned;
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
(*Interface of the order-sorted algebra of type classes.*)
signature SORTS =
sig
  (*ordered lists of sorts*)
  val union: sort list -> sort list -> sort list
  val subtract: sort list -> sort list -> sort list
  val remove_sort: sort -> sort list -> sort list
  val insert_sort: sort -> sort list -> sort list
  val insert_typ: typ -> sort list -> sort list
  val insert_typs: typ list -> sort list -> sort list
  val insert_term: term -> sort list -> sort list
  val insert_terms: term list -> sort list -> sort list
  (*order-sorted algebra: class graph plus table of type arities*)
  type algebra
  val rep_algebra: algebra ->
   {classes: serial Graph.T,
    arities: (class * (class * sort list)) list Symtab.table}
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  (*class and sort relations*)
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  (*intersection, normal forms, certification*)
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
  (*build algebras*)
  val add_class: Pretty.pp -> class * class list -> algebra -> algebra
  val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
  val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
  val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list)
    -> algebra -> (sort -> sort) * algebra
  (*sorts of types; errors carry structured data, messages are composed by class_error*)
  type class_error
  val class_error: Pretty.pp -> class_error -> string
  exception CLASS_ERROR of class_error
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort -> sort Vartab.table -> sort Vartab.table
  val of_sort: algebra -> typ * sort -> bool
  val weaken: algebra -> ('a * class -> class -> 'a) -> 'a * class -> class -> 'a
  val of_sort_derivation: Pretty.pp -> algebra ->
    {class_relation: 'a * class -> class -> 'a,
     type_constructor: string -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
end;
    63 
    64 structure Sorts: SORTS =
    65 struct
    66 
    67 
    68 (** ordered lists of sorts **)
    69 
    70 val op union = OrdList.union Term.sort_ord;
    71 val subtract = OrdList.subtract Term.sort_ord;
    72 
    73 val remove_sort = OrdList.remove Term.sort_ord;
    74 val insert_sort = OrdList.insert Term.sort_ord;
    75 
    76 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    77   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    78   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    79 and insert_typs [] Ss = Ss
    80   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    81 
    82 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    83   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    84   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    85   | insert_term (Bound _) Ss = Ss
    86   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    87   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    88 
    89 fun insert_terms [] Ss = Ss
    90   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    91 
    92 
    93 
    94 (** order-sorted algebra **)
    95 
    96 (*
    97   classes: graph representing class declarations together with proper
    98     subclass relation, which needs to be transitive and acyclic.
    99 
   100   arities: table of association lists of all type arities; (t, ars)
   101     means that type constructor t has the arities ars; an element
   102     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
   103     via c0 <= c.  "Coregularity" of the arities structure requires
   104     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 such that
   105     c1 <= c2 holds Ss1 <= Ss2.
   106 *)
   107 
(*an algebra pairs the class graph (nodes carry a serial number) with the
  per-type-constructor association lists of arities*)
datatype algebra = Algebra of
 {classes: serial Graph.T,
  arities: (class * (class * sort list)) list Symtab.table};
   111 
   112 fun rep_algebra (Algebra args) = args;
   113 
   114 val classes_of = #classes o rep_algebra;
   115 val arities_of = #arities o rep_algebra;
   116 
   117 fun make_algebra (classes, arities) =
   118   Algebra {classes = classes, arities = arities};
   119 
   120 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   121 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   122 
   123 
   124 (* classes *)
   125 
   126 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   127 
   128 val super_classes = Graph.imm_succs o classes_of;
   129 
   130 
   131 (* class relations *)
   132 
   133 val class_less = Graph.is_edge o classes_of;
   134 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   135 
   136 
   137 (* sort relations *)
   138 
   139 fun sort_le algebra (S1, S2) =
   140   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   141 
   142 fun sorts_le algebra (Ss1, Ss2) =
   143   ListPair.all (sort_le algebra) (Ss1, Ss2);
   144 
   145 fun sort_eq algebra (S1, S2) =
   146   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   147 
   148 
   149 (* intersection *)
   150 
   151 fun inter_class algebra c S =
   152   let
   153     fun intr [] = [c]
   154       | intr (S' as c' :: c's) =
   155           if class_le algebra (c', c) then S'
   156           else if class_le algebra (c, c') then intr c's
   157           else c' :: intr c's
   158   in intr S end;
   159 
   160 fun inter_sort algebra (S1, S2) =
   161   sort_strings (fold (inter_class algebra) S1 S2);
   162 
   163 
   164 (* normal forms *)
   165 
   166 fun minimize_sort _ [] = []
   167   | minimize_sort _ (S as [_]) = S
   168   | minimize_sort algebra S =
   169       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   170       |> sort_distinct string_ord;
   171 
   172 fun complete_sort algebra =
   173   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   174 
   175 
   176 (* certify *)
   177 
   178 fun certify_class algebra c =
   179   if can (Graph.get_node (classes_of algebra)) c then c
   180   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   181 
   182 fun certify_sort classes = minimize_sort classes o map (certify_class classes);
   183 
   184 
   185 
   186 (** build algebras **)
   187 
   188 (* classes *)
   189 
   190 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   191 
   192 fun err_cyclic_classes pp css =
   193   error (cat_lines (map (fn cs =>
   194     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   195 
   196 fun add_class pp (c, cs) = map_classes (fn classes =>
   197   let
   198     val classes' = classes |> Graph.new_node (c, serial ())
   199       handle Graph.DUP dup => err_dup_class dup;
   200     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   201       handle Graph.CYCLES css => err_cyclic_classes pp css;
   202   in classes'' end);
   203 
   204 
   205 (* arities *)
   206 
   207 local
   208 
   209 fun for_classes _ NONE = ""
   210   | for_classes pp (SOME (c1, c2)) =
   211       " for classes " ^ Pretty.string_of_classrel pp [c1, c2];
   212 
   213 fun err_conflict pp t cc (c, Ss) (c', Ss') =
   214   error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
   215     Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
   216     Pretty.string_of_arity pp (t, Ss', [c']));
   217 
   218 fun coregular pp algebra t (c, (c0, Ss)) ars =
   219   let
   220     fun conflict (c', (_, Ss')) =
   221       if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
   222         SOME ((c, c'), (c', Ss'))
   223       else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
   224         SOME ((c', c), (c', Ss'))
   225       else NONE;
   226   in
   227     (case get_first conflict ars of
   228       SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
   229     | NONE => (c, (c0, Ss)) :: ars)
   230   end;
   231 
   232 fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);
   233 
   234 fun insert pp algebra t (c, (c0, Ss)) ars =
   235   (case AList.lookup (op =) ars c of
   236     NONE => coregular pp algebra t (c, (c0, Ss)) ars
   237   | SOME (_, Ss') =>
   238       if sorts_le algebra (Ss, Ss') then ars
   239       else if sorts_le algebra (Ss', Ss) then
   240         coregular pp algebra t (c, (c0, Ss))
   241           (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
   242       else err_conflict pp t NONE (c, Ss) (c, Ss'));
   243 
   244 fun insert_ars pp algebra (t, ars) arities =
   245   let val ars' =
   246     Symtab.lookup_list arities t
   247     |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
   248   in Symtab.update (t, ars') arities end;
   249 
   250 in
   251 
   252 fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);
   253 
   254 fun add_arities_table pp algebra =
   255   Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));
   256 
   257 end;
   258 
   259 
   260 (* classrel *)
   261 
   262 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   263   Symtab.empty
   264   |> add_arities_table pp algebra arities);
   265 
   266 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   267   classes |> Graph.add_edge_trans_acyclic rel
   268     handle Graph.CYCLES css => err_cyclic_classes pp css);
   269 
   270 
   271 (* empty and merge *)
   272 
   273 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   274 
   275 fun merge_algebra pp
   276    (Algebra {classes = classes1, arities = arities1},
   277     Algebra {classes = classes2, arities = arities2}) =
   278   let
   279     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   280       handle Graph.DUP c => err_dup_class c
   281           | Graph.CYCLES css => err_cyclic_classes pp css;
   282     val algebra0 = make_algebra (classes', Symtab.empty);
   283     val arities' = Symtab.empty
   284       |> add_arities_table pp algebra0 arities1
   285       |> add_arities_table pp algebra0 arities2;
   286   in make_algebra (classes', arities') end;
   287 
   288 
   289 (* subalgebra *)
   290 
   291 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   292   let
   293     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   294     fun restrict_arity tyco (c, (_, Ss)) =
   295       if P c then
   296         SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) (sargs (c, tyco))
   297           |> map restrict_sort))
   298       else NONE;
   299     val classes' = classes |> Graph.subgraph P;
   300     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   301   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   302 
   303 
   304 
   305 (** sorts of types **)
   306 
   307 (* errors -- delayed message composition *)
   308 
(*structured failure data; turned into a message by function class_error*)
datatype class_error =
  NoClassrel of class * class |   (*the two classes of the missing relation*)
  NoArity of string * class |     (*type constructor and class of the missing arity*)
  NoSubsort of sort * sort;       (*the two sorts of the underivable subsort relation*)
   313 
   314 fun class_error pp (NoClassrel (c1, c2)) =
   315       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   316   | class_error pp (NoArity (a, c)) =
   317       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   318   | class_error pp (NoSubsort (S1, S2)) =
   319      "Cannot derive subsort relation " ^ Pretty.string_of_sort pp S1
   320        ^ " < " ^ Pretty.string_of_sort pp S2;
   321 
(*raised by mg_domain, meet_sort, weaken and of_sort_derivation below*)
exception CLASS_ERROR of class_error;
   323 
   324 
   325 (* mg_domain *)
   326 
   327 fun mg_domain algebra a S =
   328   let
   329     val arities = arities_of algebra;
   330     fun dom c =
   331       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   332         NONE => raise CLASS_ERROR (NoArity (a, c))
   333       | SOME (_, Ss) => Ss);
   334     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   335   in
   336     (case S of
   337       [] => raise Fail "Unknown domain of empty intersection"
   338     | c :: cs => fold dom_inter cs (dom c))
   339   end;
   340 
   341 
   342 (* meet_sort *)
   343 
   344 fun meet_sort algebra =
   345   let
   346     fun inters S S' = inter_sort algebra (S, S');
   347     fun meet _ [] = I
   348       | meet (TFree (_, S)) S' =
   349           if sort_le algebra (S, S') then I
   350           else raise CLASS_ERROR (NoSubsort (S, S'))
   351       | meet (TVar (v, S)) S' =
   352           if sort_le algebra (S, S') then I
   353           else Vartab.map_default (v, S) (inters S')
   354       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   355   in uncurry meet end;
   356 
   357 
   358 (* of_sort *)
   359 
   360 fun of_sort algebra =
   361   let
   362     fun ofS (_, []) = true
   363       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   364       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   365       | ofS (Type (a, Ts), S) =
   366           let val Ss = mg_domain algebra a S in
   367             ListPair.all ofS (Ts, Ss)
   368           end handle CLASS_ERROR _ => false;
   369   in ofS end;
   370 
   371 
   372 (* animating derivations *)
   373 
   374 fun weaken algebra class_relation =
   375   let
   376     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   377       | path (x, _) = x;
   378   in fn (x, c1) => fn c2 =>
   379     (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   380       [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   381     | cs :: _ => path (x, cs))
   382   end;
   383 
(*Produce one derivation per class of the given sort for the given type,
  built from the three user-supplied operations class_relation,
  type_constructor and type_variable.  May raise CLASS_ERROR.
  Note: argument pp is not used in the body.*)
fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
  let
    (*local instance of weaken, shadowing the top-level function*)
    val weaken = weaken algebra class_relation;
    val arities = arities_of algebra;

    (*for each class c2 of S2: take the first (derivation, class) pair of S1
      whose class is below c2 and weaken it to c2*)
    fun weakens S1 S2 = S2 |> map (fn c2 =>
      (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
        SOME d1 => weaken d1 c2
      | NONE => raise CLASS_ERROR (NoSubsort (map #2 S1, S2))));

    (*empty sort: nothing to derive*)
    fun derive _ [] = []
      | derive (Type (a, Ts)) S =
          let
            (*required argument sorts, and derivations of the arguments
              paired with the classes they derive*)
            val Ss = mg_domain algebra a S;
            val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
          in
            S |> map (fn c =>
              let
                (*mg_domain above has already looked up the arity of every
                  class of S, so a missing entry would have raised there*)
                val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
                (*weaken the argument derivations to the sorts of this arity*)
                val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
              (*build the derivation at the original class c0, then weaken to c*)
              in weaken (type_constructor a dom' c0, c0) c end)
          end
      (*remaining case: type variables*)
      | derive T S = weakens (type_variable T) S;
  in uncurry derive end;
   408 
   409 
   410 (* witness_sorts *)
   411 
(*Search for witness types of the given sorts, built from the type
  constructors in types and from hypothetical sorts in hyps.  The result
  contains (type, sort) pairs for those sorts where a witness was found.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse an already solved witness T of sort S1 if S1 <= S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*a hypothesis S1 <= S2 yields the free type variable 'hyp of sort S1*)
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    (*mg_domain with failure turned into NONE*)
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*the state threaded through the search is the pair (solved, failed);
      the empty sort is witnessed by propT*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*give up immediately if S is below some sort that failed before*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get_solved S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in turn; when none is left, record S as failed*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  (*t witnesses S only if every argument sort has a witness*)
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  (*drop the sorts without witness*)
  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   449 
   450 end;