src/Pure/sorts.ML
author blanchet
Wed Mar 04 11:05:29 2009 +0100 (2009-03-04)
changeset 30242 aea5d7fa7ef5
parent 30240 5b25fee0362c
parent 30065 c9a1e0f7621b
child 31946 99ac0321cd47
permissions -rw-r--r--
Merge.
     1 (*  Title:      Pure/sorts.ML
     2     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     3 
     4 The order-sorted algebra of type classes.
     5 
     6 Classes denote (possibly empty) collections of types that are
     7 partially ordered by class inclusion. They are represented
     8 symbolically by strings.
     9 
    10 Sorts are intersections of finitely many classes. They are represented
    11 by lists of classes.  Normal forms of sorts are sorted lists of
    12 minimal classes (wrt. current class inclusion).
    13 *)
    14 
signature SORTS =
sig
  (*ordered lists of sorts (canonical order TermOrd.sort_ord)*)
  val make: sort list -> sort OrdList.T
  val subset: sort OrdList.T * sort OrdList.T -> bool
  val union: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val subtract: sort OrdList.T -> sort OrdList.T -> sort OrdList.T
  val remove_sort: sort -> sort OrdList.T -> sort OrdList.T
  val insert_sort: sort -> sort OrdList.T -> sort OrdList.T
  (*collect the sorts of all type variables occurring in types / terms*)
  val insert_typ: typ -> sort OrdList.T -> sort OrdList.T
  val insert_typs: typ list -> sort OrdList.T -> sort OrdList.T
  val insert_term: term -> sort OrdList.T -> sort OrdList.T
  val insert_terms: term list -> sort OrdList.T -> sort OrdList.T
  (*the order-sorted algebra: class graph plus table of type arities*)
  type algebra
  val rep_algebra: algebra ->
   {classes: serial Graph.T,
    arities: (class * (class * sort list)) list Symtab.table}
  (*classes and the class / sort orderings*)
  val all_classes: algebra -> class list
  val super_classes: algebra -> class -> class list
  val class_less: algebra -> class * class -> bool
  val class_le: algebra -> class * class -> bool
  val sort_eq: algebra -> sort * sort -> bool
  val sort_le: algebra -> sort * sort -> bool
  val sorts_le: algebra -> sort list * sort list -> bool
  (*intersection and normal forms of sorts*)
  val inter_sort: algebra -> sort * sort -> sort
  val minimize_sort: algebra -> sort -> sort
  val complete_sort: algebra -> sort -> sort
  val minimal_sorts: algebra -> sort list -> sort OrdList.T
  (*check that classes are declared in the algebra*)
  val certify_class: algebra -> class -> class    (*exception TYPE*)
  val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
  (*building algebras; declaration errors are reported via ERROR*)
  val add_class: Pretty.pp -> class * class list -> algebra -> algebra
  val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
  val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
  val empty_algebra: algebra
  val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
  val subalgebra: Pretty.pp -> (class -> bool) -> (class * string -> sort list option)
    -> algebra -> (sort -> sort) * algebra
  (*errors of the sort-checking operations below, with delayed message composition*)
  type class_error
  val class_error: Pretty.pp -> class_error -> string
  exception CLASS_ERROR of class_error
  (*sorts of types*)
  val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
  val meet_sort: algebra -> typ * sort
    -> sort Vartab.table -> sort Vartab.table   (*exception CLASS_ERROR*)
  val meet_sort_typ: algebra -> typ * sort -> typ -> typ   (*exception CLASS_ERROR*)
  val of_sort: algebra -> typ * sort -> bool
  (*animating sort derivations via user-supplied callbacks*)
  val weaken: algebra -> ('a * class -> class -> 'a) -> 'a * class -> class -> 'a
  val of_sort_derivation: Pretty.pp -> algebra ->
    {class_relation: 'a * class -> class -> 'a,
     type_constructor: string -> ('a * class) list list -> class -> 'a,
     type_variable: typ -> ('a * class) list} ->
    typ * sort -> 'a list   (*exception CLASS_ERROR*)
  (*search for witness types of (possibly empty) sorts*)
  val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
end;
    67 
    68 structure Sorts: SORTS =
    69 struct
    70 
    71 
    72 (** ordered lists of sorts **)
    73 
    74 val make = OrdList.make TermOrd.sort_ord;
    75 val op subset = OrdList.subset TermOrd.sort_ord;
    76 val op union = OrdList.union TermOrd.sort_ord;
    77 val subtract = OrdList.subtract TermOrd.sort_ord;
    78 
    79 val remove_sort = OrdList.remove TermOrd.sort_ord;
    80 val insert_sort = OrdList.insert TermOrd.sort_ord;
    81 
    82 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    83   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    84   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    85 and insert_typs [] Ss = Ss
    86   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    87 
    88 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    89   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    90   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    91   | insert_term (Bound _) Ss = Ss
    92   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    93   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    94 
    95 fun insert_terms [] Ss = Ss
    96   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    97 
    98 
    99 
   100 (** order-sorted algebra **)
   101 
   102 (*
   103   classes: graph representing class declarations together with proper
   104     subclass relation, which needs to be transitive and acyclic.
   105 
   106   arities: table of association lists of all type arities; (t, ars)
   107     means that type constructor t has the arities ars; an element
   108     (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
   109     via c0 <= c.  "Coregularity" of the arities structure requires
   110     that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 such that
   111     c1 <= c2 holds Ss1 <= Ss2.
   112 *)
   113 
   114 datatype algebra = Algebra of
   115  {classes: serial Graph.T,
   116   arities: (class * (class * sort list)) list Symtab.table};
   117 
   118 fun rep_algebra (Algebra args) = args;
   119 
   120 val classes_of = #classes o rep_algebra;
   121 val arities_of = #arities o rep_algebra;
   122 
   123 fun make_algebra (classes, arities) =
   124   Algebra {classes = classes, arities = arities};
   125 
   126 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   127 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   128 
   129 
   130 (* classes *)
   131 
   132 fun all_classes (Algebra {classes, ...}) = Graph.all_preds classes (Graph.maximals classes);
   133 
   134 val super_classes = Graph.imm_succs o classes_of;
   135 
   136 
   137 (* class relations *)
   138 
   139 val class_less = Graph.is_edge o classes_of;
   140 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   141 
   142 
   143 (* sort relations *)
   144 
   145 fun sort_le algebra (S1, S2) =
   146   S1 = S2 orelse forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   147 
   148 fun sorts_le algebra (Ss1, Ss2) =
   149   ListPair.all (sort_le algebra) (Ss1, Ss2);
   150 
   151 fun sort_eq algebra (S1, S2) =
   152   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   153 
   154 
   155 (* intersection *)
   156 
   157 fun inter_class algebra c S =
   158   let
   159     fun intr [] = [c]
   160       | intr (S' as c' :: c's) =
   161           if class_le algebra (c', c) then S'
   162           else if class_le algebra (c, c') then intr c's
   163           else c' :: intr c's
   164   in intr S end;
   165 
   166 fun inter_sort algebra (S1, S2) =
   167   sort_strings (fold (inter_class algebra) S1 S2);
   168 
   169 
   170 (* normal forms *)
   171 
   172 fun minimize_sort _ [] = []
   173   | minimize_sort _ (S as [_]) = S
   174   | minimize_sort algebra S =
   175       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   176       |> sort_distinct string_ord;
   177 
   178 fun complete_sort algebra =
   179   Graph.all_succs (classes_of algebra) o minimize_sort algebra;
   180 
   181 fun minimal_sorts algebra raw_sorts =
   182   let
   183     fun le S1 S2 = sort_le algebra (S1, S2);
   184     val sorts = make (map (minimize_sort algebra) raw_sorts);
   185   in sorts |> filter_out (fn S => exists (fn S' => le S' S andalso not (le S S')) sorts) end;
   186 
   187 
   188 (* certify *)
   189 
   190 fun certify_class algebra c =
   191   if can (Graph.get_node (classes_of algebra)) c then c
   192   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   193 
   194 fun certify_sort classes = minimize_sort classes o map (certify_class classes);
   195 
   196 
   197 
   198 (** build algebras **)
   199 
   200 (* classes *)
   201 
   202 fun err_dup_class c = error ("Duplicate declaration of class: " ^ quote c);
   203 
   204 fun err_cyclic_classes pp css =
   205   error (cat_lines (map (fn cs =>
   206     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   207 
   208 fun add_class pp (c, cs) = map_classes (fn classes =>
   209   let
   210     val classes' = classes |> Graph.new_node (c, serial ())
   211       handle Graph.DUP dup => err_dup_class dup;
   212     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   213       handle Graph.CYCLES css => err_cyclic_classes pp css;
   214   in classes'' end);
   215 
   216 
   217 (* arities *)
   218 
local

(*optional " for classes c1 < c2" suffix of conflict messages*)
fun for_classes _ NONE = ""
  | for_classes pp (SOME (c1, c2)) =
      " for classes " ^ Pretty.string_of_classrel pp [c1, c2];

(*raise ERROR reporting the two conflicting arities t::(Ss)c and t::(Ss')c'*)
fun err_conflict pp t cc (c, Ss) (c', Ss') =
  error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
    Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
    Pretty.string_of_arity pp (t, Ss', [c']));

(*Prepend arity (c, (c0, Ss)) to ars unless it violates coregularity with an
  existing entry (c', (_, Ss')): c <= c' requires Ss <= Ss', and c' <= c
  requires Ss' <= Ss.  The first violation found raises ERROR.*)
fun coregular pp algebra t (c, (c0, Ss)) ars =
  let
    fun conflict (c', (_, Ss')) =
      if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
        SOME ((c, c'), (c', Ss'))
      else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
        SOME ((c', c), (c', Ss'))
      else NONE;
  in
    (case get_first conflict ars of
      SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
    | NONE => (c, (c0, Ss)) :: ars)
  end;

(*Expand an arity declared at c0 into entries for c0 and for each of
  super_classes c0, each remembering c0 as origin.  Class edges are added
  with Graph.add_edge_trans_acyclic, so these are presumably all
  superclasses of c0.*)
fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);

(*Insert one arity entry for class c into ars:
  - no entry for c yet: add it, subject to the coregularity check;
  - existing entry Ss' with Ss <= Ss': keep ars unchanged;
  - existing entry with Ss' <= Ss: replace the old entry by the new one (checked);
  - incomparable domains: raise ERROR.*)
fun insert pp algebra t (c, (c0, Ss)) ars =
  (case AList.lookup (op =) ars c of
    NONE => coregular pp algebra t (c, (c0, Ss)) ars
  | SOME (_, Ss') =>
      if sorts_le algebra (Ss, Ss') then ars
      else if sorts_le algebra (Ss', Ss) then
        coregular pp algebra t (c, (c0, Ss))
          (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
      else err_conflict pp t NONE (c, Ss) (c, Ss'));

(*Add the completed arities ars of type constructor t to the arities table.
  The nested fold_rev fixes the insertion order; it is kept as is.*)
fun insert_ars pp algebra (t, ars) arities =
  let val ars' =
    Symtab.lookup_list arities t
    |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
  in Symtab.update (t, ars') arities end;

in

(*add arities (c0, Ss) for one type constructor, checked against the given algebra*)
fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);

(*Re-insert all arities of a table into an accumulator table.  Only the
  (origin class, domain) part is kept (map snd); derived entries are
  recomputed by complete wrt. the given algebra.*)
fun add_arities_table pp algebra =
  Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));

end;
   270 
   271 
   272 (* classrel *)
   273 
   274 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   275   Symtab.empty
   276   |> add_arities_table pp algebra arities);
   277 
   278 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   279   classes |> Graph.add_edge_trans_acyclic rel
   280     handle Graph.CYCLES css => err_cyclic_classes pp css);
   281 
   282 
   283 (* empty and merge *)
   284 
   285 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   286 
   287 fun merge_algebra pp
   288    (Algebra {classes = classes1, arities = arities1},
   289     Algebra {classes = classes2, arities = arities2}) =
   290   let
   291     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   292       handle Graph.DUP c => err_dup_class c
   293           | Graph.CYCLES css => err_cyclic_classes pp css;
   294     val algebra0 = make_algebra (classes', Symtab.empty);
   295     val arities' = Symtab.empty
   296       |> add_arities_table pp algebra0 arities1
   297       |> add_arities_table pp algebra0 arities2;
   298   in make_algebra (classes', arities') end;
   299 
   300 
   301 (* algebra projections *)
   302 
   303 fun subalgebra pp P sargs (algebra as Algebra {classes, arities}) =
   304   let
   305     val restrict_sort = minimize_sort algebra o filter P o Graph.all_succs classes;
   306     fun restrict_arity tyco (c, (_, Ss)) =
   307       if P c then case sargs (c, tyco)
   308        of SOME sorts => SOME (c, (c, Ss |> map2 (curry (inter_sort algebra)) sorts
   309           |> map restrict_sort))
   310         | NONE => NONE
   311       else NONE;
   312     val classes' = classes |> Graph.subgraph P;
   313     val arities' = arities |> Symtab.map' (map_filter o restrict_arity);
   314   in (restrict_sort, rebuild_arities pp (make_algebra (classes', arities'))) end;
   315 
   316 
   317 
   318 (** sorts of types **)
   319 
   320 (* errors -- delayed message composition *)
   321 
   322 datatype class_error =
   323   NoClassrel of class * class |
   324   NoArity of string * class |
   325   NoSubsort of sort * sort;
   326 
   327 fun class_error pp (NoClassrel (c1, c2)) =
   328       "No class relation " ^ Pretty.string_of_classrel pp [c1, c2]
   329   | class_error pp (NoArity (a, c)) =
   330       "No type arity " ^ Pretty.string_of_arity pp (a, [], [c])
   331   | class_error pp (NoSubsort (S1, S2)) =
   332      "Cannot derive subsort relation " ^ Pretty.string_of_sort pp S1
   333        ^ " < " ^ Pretty.string_of_sort pp S2;
   334 
   335 exception CLASS_ERROR of class_error;
   336 
   337 
   338 (* mg_domain *)
   339 
   340 fun mg_domain algebra a S =
   341   let
   342     val arities = arities_of algebra;
   343     fun dom c =
   344       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   345         NONE => raise CLASS_ERROR (NoArity (a, c))
   346       | SOME (_, Ss) => Ss);
   347     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   348   in
   349     (case S of
   350       [] => raise Fail "Unknown domain of empty intersection"
   351     | c :: cs => fold dom_inter cs (dom c))
   352   end;
   353 
   354 
   355 (* meet_sort *)
   356 
   357 fun meet_sort algebra =
   358   let
   359     fun inters S S' = inter_sort algebra (S, S');
   360     fun meet _ [] = I
   361       | meet (TFree (_, S)) S' =
   362           if sort_le algebra (S, S') then I
   363           else raise CLASS_ERROR (NoSubsort (S, S'))
   364       | meet (TVar (v, S)) S' =
   365           if sort_le algebra (S, S') then I
   366           else Vartab.map_default (v, S) (inters S')
   367       | meet (Type (a, Ts)) S = fold2 meet Ts (mg_domain algebra a S);
   368   in uncurry meet end;
   369 
   370 fun meet_sort_typ algebra (T, S) =
   371   let
   372     val tab = meet_sort algebra (T, S) Vartab.empty;
   373   in Term.map_type_tvar (fn (v, _) =>
   374     TVar (v, (the o Vartab.lookup tab) v))
   375   end;
   376 
   377 
   378 (* of_sort *)
   379 
   380 fun of_sort algebra =
   381   let
   382     fun ofS (_, []) = true
   383       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   384       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   385       | ofS (Type (a, Ts), S) =
   386           let val Ss = mg_domain algebra a S in
   387             ListPair.all ofS (Ts, Ss)
   388           end handle CLASS_ERROR _ => false;
   389   in ofS end;
   390 
   391 
   392 (* animating derivations *)
   393 
   394 fun weaken algebra class_relation =
   395   let
   396     fun path (x, c1 :: c2 :: cs) = path (class_relation (x, c1) c2, c2 :: cs)
   397       | path (x, _) = x;
   398   in fn (x, c1) => fn c2 =>
   399     (case Graph.irreducible_paths (classes_of algebra) (c1, c2) of
   400       [] => raise CLASS_ERROR (NoClassrel (c1, c2))
   401     | cs :: _ => path (x, cs))
   402   end;
   403 
(*Build derivations showing that a type has a sort, by "animating" the
  algebra with user callbacks:
    class_relation   -- step a derivation from one class to a superclass,
    type_constructor -- combine argument derivations at the arity's class,
    type_variable    -- initial derivations for a type variable.
  Returns one derivation per class of the target sort, in order.  Note that
  pp is not referenced in the body.*)
fun of_sort_derivation pp algebra {class_relation, type_constructor, type_variable} =
  let
    val weaken = weaken algebra class_relation;
    val arities = arities_of algebra;

    (*For each class c2 of S2, take the first derivation in S1 whose class is
      <= c2 and weaken it to c2; raise NoSubsort if none exists.*)
    fun weakens S1 S2 = S2 |> map (fn c2 =>
      (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
        SOME d1 => weaken d1 c2
      | NONE => raise CLASS_ERROR (NoSubsort (map #2 S1, S2))));

    fun derive _ [] = []
      | derive (Type (a, Ts)) S =
          let
            (*derive the arguments at the domain sorts, pairing each result with its class*)
            val Ss = mg_domain algebra a S;
            val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
          in
            S |> map (fn c =>
              let
                (*arity at c, originally declared at class c0 with domain Ss'*)
                val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
                val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
              in weaken (type_constructor a dom' c0, c0) c end)
          end
      | derive T S = weakens (type_variable T) S;
  in uncurry derive end;
   428 
   429 
   430 (* witness_sorts *)
   431 
(*Try to find a witness type for each sort in sorts.
    types -- names of the type constructors tried as witnesses, in order;
    hyps  -- sorts witnessed by the hypothetical free type TFree ("'hyp", S).
  Result: (witness type, sort) for each sort in sorts that could be
  witnessed; the others are dropped.  The empty sort is witnessed by propT.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse an already solved witness (T, S1) if S1 <= S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*use a hypothesis sort S1 if S1 <= S2*)
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*solved: witnesses found so far; failed: sorts known to have no witness*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*S <= F for a failed sort F: S cannot be witnessed either*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get_solved S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors one by one; S fails once all are exhausted*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   469 
   470 end;