src/Pure/sorts.ML
author haftmann
Fri Sep 01 08:36:53 2006 +0200 (2006-09-01)
changeset 20454 7e948db7e42d
parent 20391 d079804d3b82
child 20465 95f6d354b0ed
permissions -rw-r--r--
project_algebra yields sort projector
     1 (*  Title:      Pure/sorts.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Stefan Berghofer, TU Muenchen
     4 
     5 The order-sorted algebra of type classes.
     6 
     7 Classes denote (possibly empty) collections of types that are
     8 partially ordered by class inclusion. They are represented
     9 symbolically by strings.
    10 
    11 Sorts are intersections of finitely many classes. They are represented
    12 by lists of classes.  Normal forms of sorts are sorted lists of
    13 minimal classes (wrt. current class inclusion).
    14 *)
    15 
    16 signature SORTS =
    17 sig
    18   val eq_set: sort list * sort list -> bool
    19   val union: sort list -> sort list -> sort list
    20   val subtract: sort list -> sort list -> sort list
    21   val remove_sort: sort -> sort list -> sort list
    22   val insert_sort: sort -> sort list -> sort list
    23   val insert_typ: typ -> sort list -> sort list
    24   val insert_typs: typ list -> sort list -> sort list
    25   val insert_term: term -> sort list -> sort list
    26   val insert_terms: term list -> sort list -> sort list
    27   type algebra
    28   val rep_algebra: algebra ->
    29    {classes: stamp Graph.T,
    30     arities: (class * (class * sort list)) list Symtab.table}
    31   val classes: algebra -> class list
    32   val super_classes: algebra -> class -> class list
    33   val class_less: algebra -> class * class -> bool
    34   val class_le: algebra -> class * class -> bool
    35   val sort_eq: algebra -> sort * sort -> bool
    36   val sort_le: algebra -> sort * sort -> bool
    37   val sorts_le: algebra -> sort list * sort list -> bool
    38   val inter_sort: algebra -> sort * sort -> sort
    39   val certify_class: algebra -> class -> class    (*exception TYPE*)
    40   val certify_sort: algebra -> sort -> sort       (*exception TYPE*)
    41   val add_class: Pretty.pp -> class * class list -> algebra -> algebra
    42   val add_classrel: Pretty.pp -> class * class -> algebra -> algebra
    43   val add_arities: Pretty.pp -> string * (class * sort list) list -> algebra -> algebra
    44   val empty_algebra: algebra
    45   val merge_algebra: Pretty.pp -> algebra * algebra -> algebra
    46   val project_algebra: Pretty.pp -> (class -> bool) -> algebra -> algebra * (sort -> sort)
    47   type class_error
    48   val class_error: Pretty.pp -> class_error -> 'a
    49   exception CLASS_ERROR of class_error
    50   val mg_domain: algebra -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    51   val of_sort: algebra -> typ * sort -> bool
    52   val of_sort_derivation: Pretty.pp -> algebra ->
    53     {classrel: 'a * class -> class -> 'a,
    54      constructor: string -> ('a * class) list list -> class -> 'a,
    55      variable: typ -> ('a * class) list} ->
    56     typ * sort -> 'a list   (*exception CLASS_ERROR*)
    57   val witness_sorts: algebra -> string list -> sort list -> sort list -> (typ * sort) list
    58 end;
    59 
    60 structure Sorts (*: SORTS *) =
    61 struct
    62 
    63 
    64 (** ordered lists of sorts **)
    65 
    66 val eq_set = OrdList.eq_set Term.sort_ord;
    67 val op union = OrdList.union Term.sort_ord;
    68 val subtract = OrdList.subtract Term.sort_ord;
    69 
    70 val remove_sort = OrdList.remove Term.sort_ord;
    71 val insert_sort = OrdList.insert Term.sort_ord;
    72 
    73 fun insert_typ (TFree (_, S)) Ss = insert_sort S Ss
    74   | insert_typ (TVar (_, S)) Ss = insert_sort S Ss
    75   | insert_typ (Type (_, Ts)) Ss = insert_typs Ts Ss
    76 and insert_typs [] Ss = Ss
    77   | insert_typs (T :: Ts) Ss = insert_typs Ts (insert_typ T Ss);
    78 
    79 fun insert_term (Const (_, T)) Ss = insert_typ T Ss
    80   | insert_term (Free (_, T)) Ss = insert_typ T Ss
    81   | insert_term (Var (_, T)) Ss = insert_typ T Ss
    82   | insert_term (Bound _) Ss = Ss
    83   | insert_term (Abs (_, T, t)) Ss = insert_term t (insert_typ T Ss)
    84   | insert_term (t $ u) Ss = insert_term t (insert_term u Ss);
    85 
    86 fun insert_terms [] Ss = Ss
    87   | insert_terms (t :: ts) Ss = insert_terms ts (insert_term t Ss);
    88 
    89 
    90 
    91 (** order-sorted algebra **)
    92 
(*
  classes: graph representing class declarations together with the proper
    subclass relation, which needs to be transitive and acyclic.

  arities: table of association lists of all type arities; (t, ars)
    means that type constructor t has the arities ars; an element
    (c, (c0, Ss)) of ars represents the arity t::(Ss)c being derived
    via c0 <= c.  "Coregularity" of the arities structure requires
    that for any two declarations t::(Ss1)c1 and t::(Ss2)c2 with
    c1 <= c2, Ss1 <= Ss2 holds.
*)
   104 
   105 datatype algebra = Algebra of
   106  {classes: stamp Graph.T,
   107   arities: (class * (class * sort list)) list Symtab.table};
   108 
   109 fun rep_algebra (Algebra args) = args;
   110 
   111 val classes_of = #classes o rep_algebra;
   112 val arities_of = #arities o rep_algebra;
   113 
   114 fun make_algebra (classes, arities) =
   115   Algebra {classes = classes, arities = arities};
   116 
   117 fun map_classes f (Algebra {classes, arities}) = make_algebra (f classes, arities);
   118 fun map_arities f (Algebra {classes, arities}) = make_algebra (classes, f arities);
   119 
   120 
   121 (* classes *)
   122 
   123 val classes = Graph.keys o classes_of;
   124 val super_classes = Graph.imm_succs o classes_of;
   125 
   126 
   127 (* class relations *)
   128 
   129 val class_less = Graph.is_edge o classes_of;
   130 fun class_le algebra (c1, c2) = c1 = c2 orelse class_less algebra (c1, c2);
   131 
   132 
   133 (* sort relations *)
   134 
   135 fun sort_le algebra (S1, S2) =
   136   forall (fn c2 => exists (fn c1 => class_le algebra (c1, c2)) S1) S2;
   137 
   138 fun sorts_le algebra (Ss1, Ss2) =
   139   ListPair.all (sort_le algebra) (Ss1, Ss2);
   140 
   141 fun sort_eq algebra (S1, S2) =
   142   sort_le algebra (S1, S2) andalso sort_le algebra (S2, S1);
   143 
   144 
   145 (* intersection *)
   146 
   147 fun inter_class algebra c S =
   148   let
   149     fun intr [] = [c]
   150       | intr (S' as c' :: c's) =
   151           if class_le algebra (c', c) then S'
   152           else if class_le algebra (c, c') then intr c's
   153           else c' :: intr c's
   154   in intr S end;
   155 
   156 fun inter_sort algebra (S1, S2) =
   157   sort_strings (fold (inter_class algebra) S1 S2);
   158 
   159 
   160 (* normal form *)
   161 
   162 fun norm_sort _ [] = []
   163   | norm_sort _ (S as [_]) = S
   164   | norm_sort algebra S =
   165       filter (fn c => not (exists (fn c' => class_less algebra (c', c)) S)) S
   166       |> sort_distinct string_ord;
   167 
   168 
   169 (* certify *)
   170 
   171 fun certify_class algebra c =
   172   if can (Graph.get_node (classes_of algebra)) c then c
   173   else raise TYPE ("Undeclared class: " ^ quote c, [], []);
   174 
   175 fun certify_sort classes = norm_sort classes o map (certify_class classes);
   176 
   177 
   178 
   179 (** build algebras **)
   180 
   181 (* classes *)
   182 
   183 fun err_dup_classes cs =
   184   error ("Duplicate declaration of class(es): " ^ commas_quote cs);
   185 
   186 fun err_cyclic_classes pp css =
   187   error (cat_lines (map (fn cs =>
   188     "Cycle in class relation: " ^ Pretty.string_of_classrel pp cs) css));
   189 
   190 fun add_class pp (c, cs) = map_classes (fn classes =>
   191   let
   192     val classes' = classes |> Graph.new_node (c, stamp ())
   193       handle Graph.DUP dup => err_dup_classes [dup];
   194     val classes'' = classes' |> fold Graph.add_edge_trans_acyclic (map (pair c) cs)
   195       handle Graph.CYCLES css => err_cyclic_classes pp css;
   196   in classes'' end);
   197 
   198 
   199 (* arities *)
   200 
local

(*Describe the optional class pair cc for use in a conflict message;
  yields the empty string when no pair is given.*)
fun for_classes _ NONE = ""
  | for_classes pp (SOME (c1, c2)) =
      " for classes " ^ Pretty.string_of_classrel pp [c1, c2];

(*Fail (via error) reporting that the arities t::(Ss)c and t::(Ss')c'
  conflict, optionally naming the class pair cc that exposed the conflict.*)
fun err_conflict pp t cc (c, Ss) (c', Ss') =
  error ("Conflict of type arities" ^ for_classes pp cc ^ ":\n  " ^
    Pretty.string_of_arity pp (t, Ss, [c]) ^ " and\n  " ^
    Pretty.string_of_arity pp (t, Ss', [c']));

(*Add the entry (c, (c0, Ss)) to the arity list ars of type constructor t,
  checking coregularity against every existing entry (c', (_, Ss')):
  c <= c' requires Ss <= Ss', and c' <= c requires Ss' <= Ss.
  The first violating entry in list order (get_first) is reported via
  err_conflict.  Otherwise the new entry is prepended to ars.*)
fun coregular pp algebra t (c, (c0, Ss)) ars =
  let
    fun conflict (c', (_, Ss')) =
      if class_le algebra (c, c') andalso not (sorts_le algebra (Ss, Ss')) then
        SOME ((c, c'), (c', Ss'))
      else if class_le algebra (c', c) andalso not (sorts_le algebra (Ss', Ss)) then
        SOME ((c', c), (c', Ss'))
      else NONE;
  in
    (case get_first conflict ars of
      SOME ((c1, c2), (c', Ss')) => err_conflict pp t (SOME (c1, c2)) (c, Ss) (c', Ss')
    | NONE => (c, (c0, Ss)) :: ars)
  end;

(*Expand an arity (c0, Ss) declared at class c0 into the entries
  (c, (c0, Ss)), one for c0 itself and one for each class returned by
  super_classes algebra c0.  Every entry records its originating class c0.*)
fun complete algebra (c0, Ss) = map (rpair (c0, Ss)) (c0 :: super_classes algebra c0);

(*Insert one completed entry (c, (c0, Ss)) into ars:
  - no entry for c yet: add it, subject to the coregularity check;
  - existing entry with domain Ss' and Ss <= Ss': keep ars unchanged;
  - Ss' <= Ss: drop the old entry for c and add the new one, again
    subject to the coregularity check;
  - otherwise the two domains are incomparable: report a conflict.*)
fun insert pp algebra t (c, (c0, Ss)) ars =
  (case AList.lookup (op =) ars c of
    NONE => coregular pp algebra t (c, (c0, Ss)) ars
  | SOME (_, Ss') =>
      if sorts_le algebra (Ss, Ss') then ars
      else if sorts_le algebra (Ss', Ss) then
        coregular pp algebra t (c, (c0, Ss))
          (filter_out (fn (c'', (_, Ss'')) => c = c'' andalso Ss'' = Ss') ars)
      else err_conflict pp t NONE (c, Ss) (c, Ss'));

(*Insert all arities ars (each a (c0, Ss) pair) of type constructor t into
  the arities table, starting from the entries already stored for t, and
  store the resulting list under t.*)
fun insert_ars pp algebra (t, ars) arities =
  let val ars' =
    Symtab.lookup_list arities t
    |> fold_rev (fold_rev (insert pp algebra t)) (map (complete algebra) ars)
  in Symtab.update (t, ars') arities end;

in

(*Declare the arities arg = (t, ars) in algebra, checked against its current
  class relation.*)
fun add_arities pp arg algebra = algebra |> map_arities (insert_ars pp algebra arg);

(*Re-insert every stored entry of an arities table, by its (c0, Ss) origin,
  into the table given as the final argument, checked against algebra.
  An arity completed to several classes is thus re-inserted once per stored
  entry.  The repeats are kept unchanged by insert, because sorts_le holds
  reflexively.*)
fun add_arities_table pp algebra =
  Symtab.fold (fn (t, ars) => insert_ars pp algebra (t, map snd ars));

end;
   252 
   253 
   254 (* classrel *)
   255 
   256 fun rebuild_arities pp algebra = algebra |> map_arities (fn arities =>
   257   Symtab.empty
   258   |> add_arities_table pp algebra arities);
   259 
   260 fun add_classrel pp rel = rebuild_arities pp o map_classes (fn classes =>
   261   classes |> Graph.add_edge_trans_acyclic rel
   262     handle Graph.CYCLES css => err_cyclic_classes pp css);
   263 
   264 
   265 (* empty and merge *)
   266 
   267 val empty_algebra = make_algebra (Graph.empty, Symtab.empty);
   268 
   269 fun merge_algebra pp
   270    (Algebra {classes = classes1, arities = arities1},
   271     Algebra {classes = classes2, arities = arities2}) =
   272   let
   273     val classes' = Graph.merge_trans_acyclic (op =) (classes1, classes2)
   274       handle Graph.DUPS cs => err_dup_classes cs
   275           | Graph.CYCLES css => err_cyclic_classes pp css;
   276     val algebra0 = make_algebra (classes', Symtab.empty);
   277     val arities' = Symtab.empty
   278       |> add_arities_table pp algebra0 arities1
   279       |> add_arities_table pp algebra0 arities2;
   280   in make_algebra (classes', arities') end;
   281 
   282 fun project_algebra pp proj (algebra as Algebra {classes, arities}) =
   283   let
   284     val classes' = classes |> Graph.project proj;
   285     val proj' = can (Graph.get_node classes');
   286     fun proj_classes class =
   287       if proj' class
   288       then [class]
   289       else (norm_sort algebra o maps proj_classes o super_classes algebra) class;
   290     val class_subst = fold (fn class => Symtab.update (class, proj_classes class))
   291       (Graph.keys classes) Symtab.empty;
   292     fun proj_arity (c, (_, Ss)) =
   293       if proj' c then SOME (c,
   294         (c, map (norm_sort algebra o filter proj' o Graph.all_succs classes) Ss)) else NONE;
   295     val arities' = arities |> (Symtab.map o map_filter) proj_arity;
   296     val algebra' = rebuild_arities pp (make_algebra (classes', arities'));
   297     val proj_sort = norm_sort algebra' o maps (the o Symtab.lookup class_subst);
   298   in (rebuild_arities pp (make_algebra (classes', arities')), proj_sort) end;
   299 
   300 
   301 (** sorts of types **)
   302 
   303 (* errors *)
   304 
   305 datatype class_error = NoClassrel of class * class | NoArity of string * class;
   306 
   307 fun class_error pp (NoClassrel (c1, c2)) =
   308       error ("No class relation " ^ Pretty.string_of_classrel pp [c1, c2])
   309   | class_error pp (NoArity (a, c)) =
   310       error ("No type arity " ^ Pretty.string_of_arity pp (a, [], [c]));
   311 
   312 exception CLASS_ERROR of class_error;
   313 
   314 
   315 (* mg_domain *)
   316 
   317 fun mg_domain algebra a S =
   318   let
   319     val arities = arities_of algebra;
   320     fun dom c =
   321       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   322         NONE => raise CLASS_ERROR (NoArity (a, c))
   323       | SOME (_, Ss) => Ss);
   324     fun dom_inter c Ss = ListPair.map (inter_sort algebra) (dom c, Ss);
   325   in
   326     (case S of
   327       [] => raise Fail "Unknown domain of empty intersection"
   328     | c :: cs => fold dom_inter cs (dom c))
   329   end;
   330 
   331 
   332 (* of_sort *)
   333 
   334 fun of_sort algebra =
   335   let
   336     fun ofS (_, []) = true
   337       | ofS (TFree (_, S), S') = sort_le algebra (S, S')
   338       | ofS (TVar (_, S), S') = sort_le algebra (S, S')
   339       | ofS (Type (a, Ts), S) =
   340           let val Ss = mg_domain algebra a S in
   341             ListPair.all ofS (Ts, Ss)
   342           end handle CLASS_ERROR _ => false;
   343   in ofS end;
   344 
   345 
   346 (* of_sort_derivation *)
   347 
(*Build derivations of T :: S from caller-supplied callbacks, one result per
  class of S (in the order of S):
    variable T              -- derivations for a type variable, paired with
                               the class each one establishes;
    constructor a dom c0    -- derivation of an arity of a at class c0 from
                               the argument derivations dom;
    classrel (x, c1) c2     -- lift derivation x from class c1 to class c2.
  Raises CLASS_ERROR when a required arity (via mg_domain) or class relation
  is missing, and error when a variable's classes do not cover a required
  sort.*)
fun of_sort_derivation pp algebra {classrel, constructor, variable} =
  let
    val {classes, arities} = rep_algebra algebra;
    (*apply classrel step by step along a class path c1, c2, ..., starting
      from derivation x at c1*)
    fun weaken_path (x, c1 :: c2 :: cs) =
          weaken_path (classrel (x, c1) c2, c2 :: cs)
      | weaken_path (x, _) = x;
    (*lift x from c1 to c2 along the first path from Graph.irreducible_paths*)
    fun weaken (x, c1) c2 =
      (case Graph.irreducible_paths classes (c1, c2) of
        [] => raise CLASS_ERROR (NoClassrel (c1, c2))
      | cs :: _ => weaken_path (x, cs));

    (*for each class c2 of S2, weaken the first derivation in S1 whose class
      lies below c2*)
    fun weakens S1 S2 = S2 |> map (fn c2 =>
      (case S1 |> find_first (fn (_, c1) => class_le algebra (c1, c2)) of
        SOME d1 => weaken d1 c2
      | NONE => error ("Cannot derive subsort relation " ^
          Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));

    fun derive _ [] = []
      | derive (Type (a, Ts)) S =
          let
            (*derive every argument against its most general domain sort*)
            val Ss = mg_domain algebra a S;
            val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
          in
            S |> map (fn c =>
              let
                (*the stored entry for c: mg_domain above already raised
                  CLASS_ERROR if any class of S lacked one, so `the` succeeds*)
                val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
                (*weaken argument derivations to the domain of the entry's arity*)
                val dom' = map2 (fn d => fn S' => weakens d S' ~~ S') dom Ss';
              (*apply the arity at its origin class c0, then lift c0 to c*)
              in weaken (constructor a dom' c0, c0) c end)
          end
      | derive T S = weakens (variable T) S;
  in uncurry derive end;
   379 
   380 
   381 (* witness_sorts *)
   382 
(*Search for witness types of the given sorts.  Candidates are built from the
  type constructors in types and from the hypothetical sorts hyps (each
  witnessed by a type variable 'hyp of that sort).  Returns a (T, S) pair for
  each sort S of sorts for which a witness T was found; sorts without a
  witness are omitted.*)
fun witness_sorts algebra types hyps sorts =
  let
    fun le S1 S2 = sort_le algebra (S1, S2);
    (*reuse an already found witness of a sort S1 <= S2*)
    fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
    (*a hypothesis S1 <= S2 is witnessed by a type variable of sort S1*)
    fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
    fun mg_dom t S = SOME (mg_domain algebra t S) handle CLASS_ERROR _ => NONE;

    (*The state (solved, failed) is threaded through the search: solved holds
      witnesses found so far, failed holds sorts for which the constructor
      search was exhausted.  path holds the sorts currently being searched
      further up.*)
    fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
      | witn_sort path S (solved, failed) =
          (*S <= some failed sort: it is at least as hard, so give up*)
          if exists (le S) failed then (NONE, (solved, failed))
          else
            (case get_first (get_solved S) solved of
              SOME w => (SOME w, (solved, failed))
            | NONE =>
                (case get_first (get_hyp S) hyps of
                  SOME w => (SOME w, (w :: solved, failed))
                | NONE => witn_types path types S (solved, failed)))

    and witn_sorts path x = fold_map (witn_sort path) x

    (*try the type constructors in order; when all fail, record S as failed*)
    and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
      | witn_types path (t :: ts) S solved_failed =
          (case mg_dom t S of
            SOME SS =>
              (*do not descend into stronger args (achieving termination)*)
              if exists (fn D => le D S orelse exists (le D) path) SS then
                witn_types path ts S solved_failed
              else
                let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
                  (*all argument sorts witnessed: t applied to them witnesses S*)
                  if forall is_some ws then
                    let val w = (Type (t, map (#1 o the) ws), S)
                    in (SOME w, (w :: solved', failed')) end
                  else witn_types path ts S (solved', failed')
                end
          | NONE => witn_types path ts S solved_failed);

  in map_filter I (#1 (witn_sorts [] sorts ([], []))) end;
   420 
   421 end;