src/Pure/defs.ML
author wenzelm
Tue Jul 19 17:21:49 2005 +0200 (2005-07-19)
changeset 16877 e92cba1d4842
parent 16838 131ca99f6abf
child 16936 93772bd33871
permissions -rw-r--r--
tuned interfaces declare, define, finalize, merge:
canonical argument order, produce errors;
tuned checkT';
obua@16108
     1
(*  Title:      Pure/defs.ML
    ID:         $Id$
    Author:     Steven Obua, TU Muenchen

    Checks if definitions preserve consistency of logic by enforcing
    that there are no cyclic definitions. The algorithm is described in
    "An Algorithm for Determining Definitional Cycles in Higher-Order Logic with Overloading",
    Steven Obua, technical report, to be written :-)
*)
obua@16108
    10
wenzelm@16877
    11
signature DEFS =
sig
  (* the definition graph (representation is abstract) *)
  type graph

  (* building graphs *)
  val empty: graph
  val declare: string * typ -> graph -> graph
  val define: Pretty.pp -> string * typ -> string -> (string * typ) list -> graph -> graph
  val finalize: string * typ -> graph -> graph
  val merge: Pretty.pp -> graph -> graph -> graph

  (* presumably the finalized types of every constant, keyed by constant name *)
  val finals : graph -> (typ list) Symtab.table

  (* Chain history: when switched on, the exceptions CIRCULAR and INFINITE_CHAIN report
     the complete chain of definitions that led to them.  The initial setting is read
     from the Isabelle environment variable DEFS_CHAIN_HISTORY. *)
  val set_chain_history : bool -> unit
  val chain_history : unit -> bool

  (* overloading status of a constant *)
  datatype overloadingstate = Open | Closed | Final

  val overloading_info : graph -> string -> (typ * (string*typ) list * overloadingstate) option
  val fast_overloading_info : graph -> string -> (typ * int * overloadingstate) option
end
obua@16108
    33
obua@16108
    34
structure Defs :> DEFS = struct
obua@16108
    35
obua@16108
    36
type tyenv = Type.tyenv

(* Edge label (maxidx, alpha, beta, history):
     maxidx  - upper bound on the type-variable indices occurring in the label,
     alpha   - type of the defining constant,
     beta    - type at which the referenced constant is used,
     history - intermediate steps of the chain; only recorded while CHAIN_HISTORY is set. *)
type edgelabel = (int * typ * typ * (typ * string * string) list)

(* Status of a constant's node.  A node becomes Closed once a definition or finalization
   covers its most general type, and Final once it is Closed and has no outgoing edges;
   updating a Final node is treated as an internal error. *)
datatype overloadingstate = Open | Closed | Final

datatype node = Node of
    typ                                (* most general type of the constant *)
    * defnode Symtab.table             (* one defnode per definition, keyed by axiom name *)
    * (unit Symtab.table) Symtab.table (* back references: name of a referring node ->
                                          set of that node's defnodes pointing here *)
    * typ list                         (* all finalized types *)
    * overloadingstate
and defnode = Defnode of
    typ                                (* type of the constant in this definition *)
    * (edgelabel list) Symtab.table    (* outgoing edges, grouped by target node name *)
obua@16108
    55
obua@16108
    56
(* Node of constant noderef; raises Option if the constant is not in the graph. *)
fun getnode graph noderef = the (Symtab.lookup (graph, noderef))

(* All definitions recorded for a node. *)
fun get_nodedefs (Node (_, defs, _, _, _)) = defs

(* The definition named defname of a node, if any. *)
fun get_defnode node defname = Symtab.lookup (get_nodedefs node, defname)

(* Like get_defnode, but looks the node up by name first (raises Option if absent). *)
fun get_defnode' graph noderef defname = get_defnode (getnode graph noderef) defname

(* Number of entries in a symbol table. *)
fun table_size table = Symtab.foldl (fn (count, _) => count + 1) (0, table)
wenzelm@16877
    63
obua@16308
    64
(* Log of the operations applied to a graph.
   Define (constant, type, internal axiom name, original axiom name, body constants). *)
datatype graphaction =
    Declare of string * typ
  | Define of string * typ * string * string * (string * typ) list
  | Finalize of string * typ

(* (cost, map from internal to original axiom names, action log, nodes by constant name) *)
type graph = int * (string Symtab.table) * (graphaction list) * (node Symtab.table)
wenzelm@16877
    69
wenzelm@16877
    70
(* Global chain-history switch.  Initialised from DEFS_CHAIN_HISTORY: after removing all
   whitespace and upper-casing, the values "ON" and "TRUE" enable it. *)
val CHAIN_HISTORY =
    let
      val setting =
          String.translate
            (fn ch => if Char.isSpace ch then "" else String.str (Char.toUpper ch))
            (getenv "DEFS_CHAIN_HISTORY")
    in
      ref (setting = "ON" orelse setting = "TRUE")
    end

fun set_chain_history flag = CHAIN_HISTORY := flag
fun chain_history () = ! CHAIN_HISTORY

(* zero cost, no axiom names, no actions, no nodes *)
val empty = (0, Symtab.empty, [], Symtab.empty)
obua@16108
    82
obua@16108
    83
(* generic failure with a message *)
exception DEFS of string;
(* definitional cycle, as a chain of (type, constant, axiom name) *)
exception CIRCULAR of (typ * string * string) list;
(* a definition depends on a non-unifiable instance of itself; same chain format *)
exception INFINITE_CHAIN of (typ * string * string) list;
(* (constant, new axiom name, name of the overlapping existing definition) *)
exception CLASH of string * string * string;
(* (constant, finalized type the new definition overlaps) *)
exception FINAL of string * typ;

fun def_err msg = raise DEFS msg
obua@16108
    90
wenzelm@16877
    91
(* True iff none of the given definitions has an outgoing edge. *)
fun no_forwards defs =
    not (Symtab.exists (fn (_, Defnode (_, edges)) => not (Symtab.is_empty edges)) defs)
obua@16361
    96
wenzelm@16877
    97
(* Normalise a type for the graph.  All sorts are dropped, and free type variables become
   schematic variables of index 0 (so 'a and ?'a of the same name are identified).
   Schematic variables with a non-zero index are rejected with TYPE. *)
fun checkT' T =
    (case T of
       Type (c, Ts) => Type (c, map checkT' Ts)
     | TFree (a, _) => TVar ((a, 0), [])        (* FIXME !? *)
     | TVar ((a, 0), _) => TVar ((a, 0), [])
     | TVar _ => raise TYPE ("Illegal schematic type variable encountered", [T], []));

val checkT = Term.compress_type o checkT';

(* Shift the type-variable indices of ty2 above the largest index of ty1
   (assumes indices are non-negative), renaming ty2 apart from ty1. *)
fun rename ty1 ty2 = Logic.incr_tvar (maxidx_of_typ ty1 + 1) ty2;
obua@16108
   105
obua@16108
   106
(* Increment all type-variable indices of t by inc (when inc > 0).  Returns the shifted
   type together with the environment mapping each old variable to its shifted copy
   (empty sorts).  For inc <= 0 the type is returned unchanged with an empty environment. *)
fun subst_incr_tvar inc t =
    if inc <= 0 then (t, Vartab.empty)
    else
      let
        fun add_renaming (((name, idx), _), env) =
            Vartab.update (((name, idx), ([], TVar ((name, idx + inc), []))), env)
      in
        (Logic.incr_tvar inc t, List.foldl add_renaming Vartab.empty (typ_tvars t))
      end

(* Apply a type environment to a type. *)
fun subst env ty = Envir.norm_type env ty

(* Apply a type environment to the types of a chain history. *)
fun subst_history env history =
    map (fn (ty, const_name, def_name) => (subst env ty, const_name, def_name)) history
wenzelm@16877
   122
obua@16108
   123
(* Is specific an instance of general (no sort constraints)? *)
fun is_instance specific general =
    Type.typ_instance Type.empty_tsig (specific, general)

(* Same test, after renaming general apart from specific. *)
fun is_instance_r specific general =
    is_instance specific (rename specific general)

(* Most general unifier of ty1 and ty2 as a type environment, or NONE. *)
fun unify ty1 ty2 =
    let
      val (env, _) = Type.unify Type.empty_tsig (Vartab.empty, 0) (ty1, ty2)
    in
      SOME env
    end
    handle Type.TUNIFY => NONE
wenzelm@16877
   132
wenzelm@16877
   133
(* unify_r max ty1 ty2: unify ty1 and ty2 after renaming both apart.
   All indices in ty1 and ty2 are assumed to be <= max (negative max is treated as 0).
   With b = max(max, 0), ty1 is shifted into (b, 2b+1] and ty2 into (2b+1, 3b+2].
   Returns SOME (3b+2, s1, s2) with s1(ty1) = s2(ty2), where 3b+2 bounds every index in
   s1, s2 and the shifted types; NONE if the types do not unify. *)
fun unify_r max ty1 ty2 =
    let
      val bound = Int.max (max, 0)
      val (shifted1, ren1) = subst_incr_tvar (bound + 1) ty1
      val (shifted2, ren2) = subst_incr_tvar (bound + bound + 2) ty2
      (* combine a renaming with the unifier; their domains are disjoint *)
      fun combine ren env = Vartab.merge (fn _ => false) (ren, env)
    in
      case unify shifted1 shifted2 of
        SOME env => SOME (bound + bound + bound + 2, combine ren1 env, combine ren2 env)
      | NONE => NONE
    end
wenzelm@16877
   155
obua@16108
   156
(* Do ty1 and ty2 unify once ty2 has been renamed apart from ty1? *)
fun can_be_unified_r ty1 ty2 =
    (case unify ty1 (rename ty1 ty2) of
       SOME _ => true
     | NONE => false)

(* Do ty1 and ty2 unify as given (shared variables are identified)? *)
fun can_be_unified ty1 ty2 =
    (case unify ty1 ty2 of
       SOME _ => true
     | NONE => false)
wenzelm@16877
   169
obua@16308
   170
(* Keep type-variable indices of an edge small.  If the edge's index bound exceeds
   1000000, all TVar indices in u1, v1 and the history are renumbered densely from 0.
   One old-index -> new-index table is shared across all these types; sorts are dropped.
   The new bound is the number of distinct indices.  Smaller edges are returned as they are. *)
fun normalize_edge_idx (edge as (maxidx, u1, v1, history)) =
    if maxidx <= 1000000 then edge
    else
      let
        (* Renumber the types inside a list, right to left, threading the state
           (table, next free index); get_ty extracts a type, put_ty rebuilds an element. *)
        fun renumber_list renumber get_ty put_ty state xs =
            foldr
              (fn (x, (st, acc)) =>
                  let val (st', T') = renumber st (get_ty x)
                  in (st', put_ty (T', x) :: acc) end)
              (state, []) xs

        fun renumber (st as (tab, next)) (TVar ((a, i), _)) =
              (case Inttab.lookup (tab, i) of
                 SOME j => (st, TVar ((a, j), []))
               | NONE => ((Inttab.update ((i, next), tab), next + 1), TVar ((a, next), [])))
          | renumber st (Type (c, Ts)) =
              let val (st', Ts') = renumber_list renumber I fst st Ts
              in (st', Type (c, Ts')) end
          | renumber st T = (st, T)

        val (st1, u1') = renumber (Inttab.empty, 0) u1
        val (st2, v1') = renumber st1 v1
        val ((_, next), history') =
            renumber_list renumber
              (fn (ty, _, _) => ty)
              (fn (ty, (_, const_name, def_name)) => (ty, const_name, def_name))
              st2 history
      in
        (next, u1', v1', history')
      end
wenzelm@16877
   207
obua@16108
   208
(* Compare two edges by generality of alpha --> beta, after renaming the second apart.
   SOME LESS: the first is strictly more specific; SOME GREATER: the second is;
   SOME order on equally general edges prefers the shorter history (shorter = greater);
   NONE: neither is an instance of the other. *)
fun compare_edges (maxidx1, u1, v1, history1) (_, u2, v2, history2) =
    let
      val T1 = u1 --> v1
      val T2 = Logic.incr_tvar (maxidx1 + 1) (u2 --> v2)
    in
      case (is_instance T1 T2, is_instance T2 T1) of
        (true, true) => SOME (int_ord (length history2, length history1))
      | (true, false) => SOME LESS
      | (false, true) => SOME GREATER
      | (false, false) => NONE
    end
obua@16308
   223
obua@16308
   224
(* Insert edge x into a list of edges.
   - x is dropped if some element is at least as general (SOME LESS / SOME EQUAL).
   - Elements that x subsumes (SOME GREATER) are removed.
   - Incomparable elements are kept. *)
fun merge_edges_1 (x, []) = [x]
  | merge_edges_1 (x, ys as y :: rest) =
      (case compare_edges x y of
         SOME GREATER => merge_edges_1 (x, rest)
       | NONE => y :: merge_edges_1 (x, rest)
       | _ => ys)

(* Merge each edge of additions into base, one at a time. *)
fun merge_edges base additions = foldl merge_edges_1 base additions
obua@16108
   233
obua@16384
   234
(* Add a fresh Open node for constant name at type ty and log the Declare action.
   Redeclaring at an equivalent type (mutual instances) returns the graph unchanged;
   any other redeclaration raises DEFS. *)
fun declare' (g as (cost, axmap, actions, graph)) (cty as (name, ty)) =
    let
      val fresh_node = Node (ty, Symtab.empty, Symtab.empty, [], Open)
    in
      (cost, axmap, Declare cty :: actions, Symtab.update_new ((name, fresh_node), graph))
    end
    handle Symtab.DUP _ =>
      let
        val Node (declared_ty, _, _, _, _) = getnode graph name
        val same_type = is_instance_r ty declared_ty andalso is_instance_r declared_ty ty
      in
        if same_type then g
        else def_err "constant is already declared with different type"
      end

(* declare' after normalising the type with checkT. *)
fun declare'' g (name, T) = declare' g (name, checkT T)
obua@16361
   248
obua@16384
   249
(* Global counter used to make internal axiom names unique. *)
val axcounter = ref (IntInf.fromInt 0)

(* Produce the internal name axname_N for a user axiom name, advance the counter, and
   record internal -> original in axmap.  Returns (updated axmap, internal name). *)
fun newaxname axmap axname =
    let
      val n = ! axcounter
      val internal = axname ^ "_" ^ IntInf.toString n
    in
      axcounter := n + 1;
      (Symtab.update ((internal, axname), axmap), internal)
    end
obua@16384
   258
wenzelm@16877
   259
(* Re-raise exception x.  For INFINITE_CHAIN and CIRCULAR, first map the internal axiom
   names in the chain back to the user-supplied names recorded in axmap.  All other
   exceptions are re-raised unchanged.
   NOTE(review): CLASH also carries internal axiom names but is not translated here. *)
fun translate_ex axmap x =
    let
      (* Fall back to the internal name rather than raising Option: a failed lookup must
         not mask the exception that is being reported. *)
      fun original_name axname =
          (case Symtab.lookup (axmap, axname) of
             SOME name => name
           | NONE => axname)
      fun translate (ty, nodename, axname) = (ty, nodename, original_name axname)
    in
      case x of
        INFINITE_CHAIN chain => raise INFINITE_CHAIN (map translate chain)
      | CIRCULAR cycle => raise CIRCULAR (map translate cycle)
      | _ => raise x
    end
obua@16384
   269
obua@16826
   270
(* define' (cost, axmap, actions, graph) (mainref, ty) axname orig_axname body:
   add the definition axname (an internal axiom name; orig_axname is the user's name) of
   constant mainref at type ty, whose right-hand side refers to the constants in body.
   Raises DEFS if mainref or a body constant is undeclared, or if ty is not an instance of
   the declared type.  Raises CLASH if ty unifies with an existing definition of mainref,
   and FINAL if ty unifies with a finalized type.  Raises CIRCULAR / INFINITE_CHAIN if the
   new definition depends on itself.  Every exception passes through translate_ex, which
   maps axiom names in CIRCULAR / INFINITE_CHAIN chains back to user names.
   On success returns the graph with cost + 3 and a Define action prepended to the log. *)
fun define' (cost, axmap, actions, graph) (mainref, ty) axname orig_axname body =
obua@16108
   271
    let
wenzelm@16877
   272
      val mainnode  = (case Symtab.lookup (graph, mainref) of
wenzelm@16877
   273
                         NONE => def_err ("constant "^mainref^" is not declared")
wenzelm@16877
   274
                       | SOME n => n)
obua@16361
   275
      val (Node (gty, defs, backs, finals, _)) = mainnode
wenzelm@16877
   276
      val _ = (if is_instance_r ty gty then ()
obua@16308
   277
               else def_err "type of constant does not match declared type")
wenzelm@16877
   278
      (* check_def never returns true: it raises CLASH for an existing definition whose
         type unifies with ty, or DEFS on a name collision, and otherwise returns false.
         Symtab.exists below is therefore run only for these checks. *)
      fun check_def (s, Defnode (ty', _)) =
wenzelm@16877
   279
          (if can_be_unified_r ty ty' then
wenzelm@16877
   280
             raise (CLASH (mainref, axname, s))
wenzelm@16877
   281
           else if s = axname then
wenzelm@16877
   282
             def_err "name of axiom is already used for another definition of this constant"
wenzelm@16877
   283
           else false)
obua@16198
   284
      val _ = Symtab.exists check_def defs
wenzelm@16877
   285
      (* likewise run for effect: raises FINAL if ty unifies with a finalized type *)
      fun check_final finalty =
wenzelm@16877
   286
          (if can_be_unified_r finalty ty then
wenzelm@16877
   287
             raise (FINAL (mainref, finalty))
wenzelm@16877
   288
           else
wenzelm@16877
   289
             true)
obua@16198
   290
      val _ = forall check_final finals
wenzelm@16877
   291
wenzelm@16877
   292
      (* now we know that the only thing that can prevent acceptance of the definition
obua@16308
   293
         is a cyclic dependency *)
wenzelm@16877
   294
obua@16308
   295
      (* add links (index-normalized) to the edge list of nodename, merging them with
         the links already present; empty link lists leave edges unchanged *)
      fun insert_edges edges (nodename, links) =
wenzelm@16877
   296
          (if links = [] then
obua@16308
   297
             edges
obua@16308
   298
           else
obua@16308
   299
             let
obua@16308
   300
               val links = map normalize_edge_idx links
obua@16308
   301
             in
wenzelm@16877
   302
               Symtab.update ((nodename,
wenzelm@16877
   303
                               case Symtab.lookup (edges, nodename) of
wenzelm@16877
   304
                                 NONE => links
wenzelm@16877
   305
                               | SOME links' => merge_edges links' links),
obua@16308
   306
                              edges)
obua@16308
   307
             end)
wenzelm@16877
   308
obua@16308
   309
      (* Accumulate the edges contributed by one body constant (bodyn, bodyty).
         A Final node, a non-unifying type, or a type covered by a finalized type
         contributes nothing.  Otherwise the edge is composed with the edges of every
         existing definition of bodyn whose type unifies with bodyty. *)
      fun make_edges ((bodyn, bodyty), edges) =
wenzelm@16877
   310
          let
wenzelm@16877
   311
            val bnode =
wenzelm@16877
   312
                (case Symtab.lookup (graph, bodyn) of
wenzelm@16877
   313
                   NONE => def_err "body of constant definition references undeclared constant"
wenzelm@16877
   314
                 | SOME x => x)
wenzelm@16877
   315
            val (Node (general_btyp, bdefs, bbacks, bfinals, closed)) = bnode
wenzelm@16877
   316
          in
obua@16361
   317
            if closed = Final then edges else
wenzelm@16877
   318
            case unify_r 0 bodyty general_btyp of
wenzelm@16877
   319
              NONE => edges
wenzelm@16877
   320
            | SOME (maxidx, sigma1, sigma2) =>
wenzelm@16877
   321
              if exists (is_instance_r bodyty) bfinals then
obua@16308
   322
                edges
obua@16308
   323
              else
wenzelm@16877
   324
                let
wenzelm@16877
   325
                  (* compose step1 = (maxidx1, alpha1, beta1, defname) with each link of
                     the existing definition, keeping those whose types unify *)
                  fun insert_trans_edges ((step1, edges), (nodename, links)) =
obua@16308
   326
                      let
obua@16308
   327
                        val (maxidx1, alpha1, beta1, defname) = step1
wenzelm@16877
   328
                        fun connect (maxidx2, alpha2, beta2, history) =
wenzelm@16877
   329
                            case unify_r (Int.max (maxidx1, maxidx2)) beta1 alpha2 of
wenzelm@16877
   330
                              NONE => NONE
wenzelm@16877
   331
                            | SOME (max, sleft, sright) =>
wenzelm@16877
   332
                              SOME (max, subst sleft alpha1, subst sright beta2,
wenzelm@16877
   333
                                    if !CHAIN_HISTORY then
wenzelm@16877
   334
                                      ((subst sleft beta1, bodyn, defname)::
wenzelm@16877
   335
                                       (subst_history sright history))
wenzelm@16877
   336
                                    else [])
obua@16308
   337
                        val links' = List.mapPartial connect links
obua@16308
   338
                      in
obua@16308
   339
                        (step1, insert_edges edges (nodename, links'))
obua@16308
   340
                      end
wenzelm@16877
   341
obua@16308
   342
                  (* swallowed becomes true once bodyty is an instance of some existing
                     definition's type; the remaining definitions are then skipped and no
                     direct edge to bodyn is added *)
                  fun make_edges' ((swallowed, edges),
obua@16308
   343
                                   (def_name, Defnode (def_ty, def_edges))) =
wenzelm@16877
   344
                      if swallowed then
wenzelm@16877
   345
                        (swallowed, edges)
wenzelm@16877
   346
                      else
wenzelm@16877
   347
                        (case unify_r 0 bodyty def_ty of
wenzelm@16877
   348
                           NONE => (swallowed, edges)
wenzelm@16877
   349
                         | SOME (maxidx, sigma1, sigma2) =>
wenzelm@16877
   350
                           (is_instance_r bodyty def_ty,
wenzelm@16877
   351
                            snd (Symtab.foldl insert_trans_edges
obua@16308
   352
                              (((maxidx, subst sigma1 ty, subst sigma2 def_ty, def_name),
obua@16308
   353
                                edges), def_edges))))
wenzelm@16877
   354
                  val (swallowed, edges) = Symtab.foldl make_edges' ((false, edges), bdefs)
wenzelm@16877
   355
                in
wenzelm@16877
   356
                  if swallowed then
wenzelm@16877
   357
                    edges
wenzelm@16877
   358
                  else
wenzelm@16877
   359
                    insert_edges edges
obua@16308
   360
                    (bodyn, [(maxidx, subst sigma1 ty, subst sigma2 general_btyp,[])])
wenzelm@16877
   361
                end
wenzelm@16877
   362
          end
wenzelm@16877
   363
obua@16308
   364
      val edges = foldl make_edges Symtab.empty body
wenzelm@16877
   365
wenzelm@16877
   366
      (* We also have to add the backreferences that this new defnode induces. *)
obua@16308
   367
      fun install_backrefs (graph, (noderef, links)) =
obua@16308
   368
          if links <> [] then
obua@16308
   369
            let
obua@16361
   370
              val (Node (ty, defs, backs, finals, closed)) = getnode graph noderef
wenzelm@16877
   371
              val _ = if closed = Final then
wenzelm@16877
   372
                        sys_error ("install_backrefs: closed node cannot be updated")
obua@16361
   373
                      else ()
obua@16308
   374
              val defnames =
obua@16308
   375
                  (case Symtab.lookup (backs, mainref) of
obua@16308
   376
                     NONE => Symtab.empty
obua@16308
   377
                   | SOME s => s)
obua@16308
   378
              val defnames' = Symtab.update_new ((axname, ()), defnames)
obua@16308
   379
              val backs' = Symtab.update ((mainref,defnames'), backs)
obua@16308
   380
            in
obua@16361
   381
              Symtab.update ((noderef, Node (ty, defs, backs', finals, closed)), graph)
obua@16308
   382
            end
obua@16308
   383
          else
obua@16308
   384
            graph
wenzelm@16877
   385
obua@16198
   386
      val graph = Symtab.foldl install_backrefs (graph, edges)
wenzelm@16877
   387
obua@16361
   388
      (* re-read mainref's node: install_backrefs updated it if the body refers to
         mainref itself *)
      val (Node (_, _, backs, _, closed)) = getnode graph mainref
wenzelm@16877
   389
      val closed =
wenzelm@16877
   390
          if closed = Final then sys_error "define: closed node"
obua@16361
   391
          else if closed = Open andalso is_instance_r gty ty then Closed else closed
obua@16361
   392
obua@16308
   393
      val thisDefnode = Defnode (ty, edges)
wenzelm@16877
   394
      val graph = Symtab.update ((mainref, Node (gty, Symtab.update_new
obua@16361
   395
        ((axname, thisDefnode), defs), backs, finals, closed)), graph)
wenzelm@16877
   396
wenzelm@16877
   397
      (* Now we have to check all backreferences to this node and inform them about
obua@16308
   398
         the new defnode. In this section we also check for circularity. *)
obua@16308
   399
      fun update_backrefs ((backs, graph), (noderef, defnames)) =
wenzelm@16877
   400
          let
wenzelm@16877
   401
            fun update_defs ((defnames, graph),(defname, _)) =
wenzelm@16877
   402
                let
wenzelm@16877
   403
                  val (Node (nodety, nodedefs, nodebacks, nodefinals, closed)) =
obua@16361
   404
                      getnode graph noderef
obua@16361
   405
                  val _ = if closed = Final then sys_error "update_defs: closed node" else ()
wenzelm@16877
   406
                  val (Defnode (def_ty, defnode_edges)) =
obua@16308
   407
                      the (Symtab.lookup (nodedefs, defname))
wenzelm@16877
   408
                  val edges = the (Symtab.lookup (defnode_edges, mainref))
obua@16361
   409
                  (* NOTE(review): refclosed is never read *)
                  val refclosed = ref false
wenzelm@16877
   410
wenzelm@16877
   411
                  (* the type of thisDefnode is ty *)
wenzelm@16877
   412
                  fun update (e as (max, alpha, beta, history), (changed, edges)) =
wenzelm@16877
   413
                      case unify_r max beta ty of
wenzelm@16877
   414
                        NONE => (changed, e::edges)
wenzelm@16877
   415
                      | SOME (max', s_beta, s_ty) =>
wenzelm@16877
   416
                        let
wenzelm@16877
   417
                          val alpha' = subst s_beta alpha
wenzelm@16877
   418
                          val ty' = subst s_ty ty
wenzelm@16877
   419
                          (* self-reference of the new definition: unifiable source and
                             target means CIRCULAR; a non-unifiable instance means
                             INFINITE_CHAIN *)
                          val _ =
wenzelm@16877
   420
                              if noderef = mainref andalso defname = axname then
wenzelm@16877
   421
                                (case unify alpha' ty' of
wenzelm@16877
   422
                                   NONE =>
wenzelm@16877
   423
                                   if (is_instance_r ty' alpha') then
wenzelm@16877
   424
                                     raise (INFINITE_CHAIN (
wenzelm@16877
   425
                                            (alpha', mainref, axname)::
wenzelm@16877
   426
                                            (subst_history s_beta history)@
wenzelm@16877
   427
                                            [(ty', mainref, axname)]))
wenzelm@16877
   428
                                   else ()
wenzelm@16877
   429
                                 | SOME s =>
obua@16308
   430
                                   raise (CIRCULAR (
wenzelm@16877
   431
                                          (subst s alpha', mainref, axname)::
wenzelm@16877
   432
                                          (subst_history s (subst_history s_beta history))@
wenzelm@16877
   433
                                          [(subst s ty', mainref, axname)])))
wenzelm@16877
   434
                              else ()
wenzelm@16877
   435
                        in
wenzelm@16877
   436
                          (* an edge whose target is an instance of ty is dropped (changed);
                             NOTE(review): no replacement edges are added here - confirm that
                             dependencies through the new definition remain tracked *)
                          if is_instance_r beta ty then
wenzelm@16877
   437
                            (true, edges)
wenzelm@16877
   438
                          else
wenzelm@16877
   439
                            (changed, e::edges)
wenzelm@16877
   440
                        end
wenzelm@16877
   441
obua@16308
   442
                  val (changed, edges') = foldl update (false, []) edges
wenzelm@16877
   443
                  val defnames' = if edges' = [] then
wenzelm@16877
   444
                                    defnames
wenzelm@16877
   445
                                  else
obua@16308
   446
                                    Symtab.update ((defname, ()), defnames)
obua@16308
   447
                in
obua@16308
   448
                  if changed then
obua@16308
   449
                    let
wenzelm@16877
   450
                      val defnode_edges' =
obua@16308
   451
                          if edges' = [] then
obua@16308
   452
                            Symtab.delete mainref defnode_edges
obua@16308
   453
                          else
obua@16308
   454
                            Symtab.update ((mainref, edges'), defnode_edges)
obua@16308
   455
                      val defnode' = Defnode (def_ty, defnode_edges')
obua@16308
   456
                      val nodedefs' = Symtab.update ((defname, defnode'), nodedefs)
obua@16361
   457
                      val closed = if closed = Closed andalso Symtab.is_empty defnode_edges'
wenzelm@16877
   458
                                      andalso no_forwards nodedefs'
obua@16361
   459
                                   then Final else closed
wenzelm@16877
   460
                      val graph' =
wenzelm@16877
   461
                          Symtab.update
wenzelm@16877
   462
                            ((noderef,
wenzelm@16877
   463
                              Node (nodety, nodedefs', nodebacks, nodefinals, closed)),graph)
obua@16308
   464
                    in
obua@16308
   465
                      (defnames', graph')
obua@16308
   466
                    end
obua@16308
   467
                  else
obua@16308
   468
                    (defnames', graph)
obua@16308
   469
                end
wenzelm@16877
   470
wenzelm@16877
   471
            val (defnames', graph') = Symtab.foldl update_defs
obua@16308
   472
                                                   ((Symtab.empty, graph), defnames)
wenzelm@16877
   473
          in
wenzelm@16877
   474
            if Symtab.is_empty defnames' then
wenzelm@16877
   475
              (backs, graph')
wenzelm@16877
   476
            else
wenzelm@16877
   477
              let
wenzelm@16877
   478
                val backs' = Symtab.update_new ((noderef, defnames'), backs)
wenzelm@16877
   479
              in
wenzelm@16877
   480
                (backs', graph')
wenzelm@16877
   481
              end
wenzelm@16877
   482
          end
wenzelm@16877
   483
obua@16308
   484
      val (backs, graph) = Symtab.foldl update_backrefs ((Symtab.empty, graph), backs)
wenzelm@16877
   485
obua@16198
   486
      (* If a Circular exception is thrown then we never reach this point. *)
obua@16361
   487
      (* re-read mainref's node: update_backrefs may have rewritten it (self-references) *)
      val (Node (gty, defs, _, finals, closed)) = getnode graph mainref
obua@16361
   488
      val closed = if closed = Closed andalso no_forwards defs then Final else closed
wenzelm@16877
   489
      val graph = Symtab.update ((mainref, Node (gty, defs, backs, finals, closed)), graph)
obua@16826
   490
      val actions' = (Define (mainref, ty, axname, orig_axname, body))::actions
wenzelm@16877
   491
    in
obua@16384
   492
      (cost+3, axmap, actions', graph)
obua@16384
   493
    end handle ex => translate_ex axmap ex
wenzelm@16877
   494
wenzelm@16877
   495
(* Normalise a definition request and hand it to define'.
   - ty and the body types are normalised with checkT.
   - Body references to Final constants are dropped (make_edges in define' ignores such
     nodes anyway), and duplicate body entries are removed.
   - A fresh internal axiom name is allocated for orig_axname.
   An undeclared body constant is deliberately passed through, so that define' reports it
   with its regular DEFS error; the previous getnode call failed here first with an
   uninformative Option exception. *)
fun define'' (cost, axmap, actions, graph) (mainref, ty) orig_axname body =
    let
      val ty = checkT ty
      fun checkbody (n, t) =
          (case Symtab.lookup (graph, n) of
             SOME (Node (_, _, _, _, Final)) => NONE
           | _ => SOME (n, checkT t))
      val body = distinct (List.mapPartial checkbody body)
      val (axmap, axname) = newaxname axmap orig_axname
    in
      define' (cost, axmap, actions, graph) (mainref, ty) axname orig_axname body
    end
obua@16308
   511
wenzelm@16877
   512
(*Mark constant noderef as final at type instance ty; returns the updated
  (cost, axmap, history, graph) quadruple.

  Fails via def_err when noderef is not declared, when ty is not an instance
  of the declared type, or when ty can be unified with the type of one of the
  existing definitions of noderef.

  If ty is an instance of a type already recorded as final, the quadruple is
  returned unchanged (no cost increment, no history entry).  Otherwise:
    - ty joins the finals list; earlier finals that are instances of ty go;
    - an Open node becomes Closed when is_instance_r nodety ty holds;
    - for every definition registered in the node's back-reference table,
      its edges to noderef whose third component is an instance of ty are
      removed; a definition keeps its back-reference entry only if some such
      edges remain, and a Closed referring node becomes Final when that
      definition's edge table is empty and no_forwards holds for its defs;
    - the node itself becomes Final if it is Closed and no_forwards holds;
    - cost is incremented and Finalize (noderef, ty) is pushed on history.*)
fun finalize' (cost, axmap, history, graph) (noderef, ty) =
  (case Symtab.lookup (graph, noderef) of
    NONE => def_err ("cannot finalize constant "^noderef^"; it is not declared")
  | SOME (Node (nodety, defs, backs, finals, closed)) =>
      let
        val _ =
          if is_instance_r ty nodety then ()
          else def_err ("only type instances of the declared constant "^
                        noderef^" can be finalized")

        (*used purely for iteration: the predicate never answers true, it
          either raises on a clash or returns false*)
        val _ = Symtab.exists
                  (fn (def_name, Defnode (def_ty, _)) =>
                      if can_be_unified_r ty def_ty then
                        def_err ("cannot finalize constant "^noderef^
                                 "; clash with definition "^def_name)
                      else false)
                  defs

        (*NONE: ty is already subsumed by a recorded final type.
          SOME ftys: new finals list containing ty, without the old entries
          that are instances of ty.*)
        fun insert_final [] = SOME [ty]
          | insert_final (fty :: ftys) =
              if is_instance_r ty fty then NONE
              else
                (case insert_final ftys of
                  NONE => NONE
                | SOME ftys' =>
                    if is_instance_r fty ty then SOME ftys'
                    else SOME (fty :: ftys'))
      in
        (case insert_final finals of
          NONE => (cost, axmap, history, graph)
        | SOME finals' =>
            let
              val closed' =
                if closed = Open andalso is_instance_r nodety ty then Closed else closed
              val graph1 =
                Symtab.update ((noderef, Node (nodety, defs, backs, finals', closed')), graph)

              (*Prune the edges to noderef covered by ty from definition defname
                of the referring constant refname; defname is recorded in kept
                when edges to noderef survive.*)
              fun prune_def refname ((g, kept), (defname, _)) =
                let
                  val (rnode as Node (rty, rdefs, rbacks, rfinals, rclosed)) =
                    getnode g refname
                  val Defnode (dty, edge_tab) = the (get_defnode rnode defname)
                  val edges =
                    (case Symtab.lookup (edge_tab, noderef) of
                      SOME es => es
                    | NONE => sys_error "finalize: corrupt backref")
                  val remaining =
                    List.filter (fn (_, _, beta, _) => not (is_instance_r beta ty)) edges
                  val (kept', edge_tab') =
                    if null remaining then
                      (kept, Symtab.delete noderef edge_tab)
                    else
                      (Symtab.update ((defname, ()), kept),
                       Symtab.update ((noderef, remaining), edge_tab))
                  val rdefs' = Symtab.update ((defname, Defnode (dty, edge_tab')), rdefs)
                  val rclosed' =
                    if rclosed = Closed andalso Symtab.is_empty edge_tab'
                       andalso no_forwards rdefs'
                    then Final else rclosed
                  val rnode' = Node (rty, rdefs', rbacks, rfinals, rclosed')
                in
                  (Symtab.update ((refname, rnode'), g), kept')
                end

              (*Process all definitions of one referring constant and rebuild its
                entry of the back-reference table (omitted when empty).*)
              fun prune_backref ((g, acc), (refname, defnames)) =
                let
                  val (g', kept) =
                    Symtab.foldl (prune_def refname) ((g, Symtab.empty), defnames)
                  val acc' =
                    if Symtab.is_empty kept then acc
                    else Symtab.update ((refname, kept), acc)
                in
                  (g', acc')
                end

              val (graph2, backs') = Symtab.foldl prune_backref ((graph1, Symtab.empty), backs)
              (*re-read the node: it may itself have been edited above*)
              val Node (_, defs2, _, _, closed2) = getnode graph2 noderef
              val closed3 = if closed2 = Closed andalso no_forwards defs2 then Final else closed2
              val graph3 =
                Symtab.update ((noderef, Node (nodety, defs2, backs', finals', closed3)), graph2)
            in
              (cost + 1, axmap, Finalize (noderef, ty) :: history, graph3)
            end)
      end)
wenzelm@16877
   605
wenzelm@16877
   606
(*finalize' applied after passing the type through checkT.*)
fun finalize'' g (c, T) =
  let val T' = checkT T in finalize' g (c, T') end
obua@16308
   607
obua@16826
   608
(*Record in the axiom-name map that ax stands for orig_ax; cost, history and
  graph are passed through unchanged.*)
fun update_axname ax orig_ax (cost, axmap, history, graph) =
  let val axmap' = Symtab.update ((ax, orig_ax), axmap)
  in (cost, axmap', history, graph) end
obua@16826
   610
obua@16361
   611
(*Replay one history action onto the quadruple g.  A Define whose axiom
  name is already among the definitions of its (declared) constant is
  skipped; otherwise the axiom-name mapping is recorded and the definition
  is replayed via define'.*)
fun merge' (Declare cty, g) = declare' g cty
  | merge' (Define (name, ty, axname, orig_axname, body), g as (_, _, _, graph)) =
      let
        val already_defined =
          (case Symtab.lookup (graph, name) of
            NONE => false
          | SOME (Node (_, defs, _, _, _)) =>
              (case Symtab.lookup (defs, axname) of
                NONE => false
              | SOME _ => true))
      in
        if already_defined then g
        else define' (update_axname axname orig_axname g) (name, ty) axname orig_axname body
      end
  | merge' (Finalize finals, g) = finalize' g finals
wenzelm@16877
   620
wenzelm@16877
   621
(*Merge two graphs by replaying the action history of the cheaper one onto
  the other; on equal cost, g2's actions are replayed onto g1.  Histories are
  built by consing new actions on the front, so foldr replays them
  oldest-first.*)
fun merge'' (g1 as (cost1, _, actions1, _)) (g2 as (cost2, _, actions2, _)) =
  let
    val (target, replay) =
      if cost1 < cost2 then (g2, actions1) else (g1, actions2)
  in
    foldr merge' target replay
  end
wenzelm@16877
   626
wenzelm@16877
   627
(*Table mapping each constant of the graph to its list of final types.*)
fun finals (_, _, _, graph) =
  let
    fun add_node (tab, (name, Node (_, _, _, ftys, _))) =
      Symtab.update_new ((name, ftys), tab)
  in
    Symtab.foldl add_node (Symtab.empty, graph)
  end
obua@16158
   632
wenzelm@16877
   633
(*For constant c: its declared type, its definitions as pairs of the
  original axiom name (looked up in axmap) and definition type, and its
  overloading state; NONE if c is not in the graph.*)
fun overloading_info (_, axmap, _, graph) c =
  (case Symtab.lookup (graph, c) of
    NONE => NONE
  | SOME (Node (T, defnodes, _, _, state)) =>
      let
        fun orig_def (ax, Defnode (U, _)) = (the (Symtab.lookup (axmap, ax)), U)
      in
        SOME (T, map orig_def (Symtab.dest defnodes), state)
      end)
wenzelm@16877
   642
wenzelm@16877
   643
(*Like overloading_info, but only counts the definitions of c instead of
  listing them.*)
fun fast_overloading_info (_, _, _, graph) c =
  (case Symtab.lookup (graph, c) of
    NONE => NONE
  | SOME (Node (T, defnodes, _, _, state)) =>
      SOME (T, Symtab.foldl (fn (n, _) => n + 1) (0, defnodes), state))
obua@16743
   652
wenzelm@16877
   653
wenzelm@16877
   654
wenzelm@16877
   655
(** diagnostics **)
wenzelm@16877
   656
wenzelm@16877
   657
(*Pretty items "c :: T" (type quoted), to be wrapped in a block by callers.*)
fun pretty_const pp (c, T) =
  let
    val frozen = Type.freeze_type T    (* FIXME zero indexes!? *)
  in
    [Pretty.str c, Pretty.str " ::", Pretty.brk 1, Pretty.quote (Pretty.typ pp frozen)]
  end;
wenzelm@16877
   660
wenzelm@16877
   661
(*One block per path element: every element but the last reads
  "c :: T depends via def on", the last one just "c :: T".*)
fun pretty_path pp path =
  let
    fun pretty [] = []
      | pretty [(T, c, _)] = [Pretty.block (pretty_const pp (c, T))]
      | pretty ((T, c, def) :: rest) =
          Pretty.block (pretty_const pp (c, T) @
            [Pretty.brk 1, Pretty.str ("depends via " ^ quote def ^ " on")]) :: pretty rest
  in
    pretty path
  end;
wenzelm@16877
   665
wenzelm@16877
   666
(*Headline for cycle/chain errors; hints at DEFS_CHAIN_HISTORY when the full
  chain is not being recorded.*)
fun chain_history_msg s =    (* FIXME huh!? *)
  s ^ (if chain_history () then ": "
       else " (set DEFS_CHAIN_HISTORY=ON for full history): ");
wenzelm@16877
   669
wenzelm@16877
   670
(*Error text for a cyclic chain of definitions along path.*)
fun defs_circular pp path =
  let
    val headline = Pretty.str (chain_history_msg "Cyclic dependency of definitions")
  in
    Pretty.string_of (Pretty.chunks (headline :: pretty_path pp path))
  end;
wenzelm@16877
   673
wenzelm@16877
   674
(*Error text for an infinite chain of definitions along path.*)
fun defs_infinite_chain pp path =
  let
    val headline = Pretty.str (chain_history_msg "Infinite chain of definitions")
  in
    Pretty.string_of (Pretty.chunks (headline :: pretty_path pp path))
  end;
wenzelm@16877
   677
wenzelm@16877
   678
(*Error text for two definitions whose types clash.*)
fun defs_clash def1 def2 =
  let val q1 = quote def1 and q2 = quote def2
  in "Type clash in definitions " ^ q1 ^ " and " ^ q2 end;
wenzelm@16877
   679
wenzelm@16877
   680
(*Error text for an attempt to define a constant that is already final.*)
fun defs_final pp const =
  Pretty.string_of (Pretty.block
    (Pretty.str "Attempt to define final constant" :: Pretty.brk 1 :: pretty_const pp const));
wenzelm@16877
   683
wenzelm@16877
   684
wenzelm@16877
   685
(* external interfaces *)
wenzelm@16877
   686
wenzelm@16877
   687
(*Declare const.  If declare'' fails for any reason caught by try, the graph
  is returned unchanged.*)
fun declare const defs =
  (case try (declare'' defs) const of
    SOME defs' => defs'
  | NONE => defs);
wenzelm@16877
   689
wenzelm@16877
   690
(*Add definition name of const with right-hand-side constants rhs.  DEFS is
  reported as a system error; the other failures are turned into user errors
  pretty-printed with pp.*)
fun define pp const name rhs defs =
  define'' defs const name rhs
    handle FINAL c => error (defs_final pp c)
      | CLASH (_, def1, def2) => error (defs_clash def1 def2)
      | CIRCULAR path => error (defs_circular pp path)
      | INFINITE_CHAIN path => error (defs_infinite_chain pp path)
      | DEFS msg => sys_error msg;
wenzelm@16877
   697
wenzelm@16877
   698
(*Mark const as final; DEFS failures are reported as system errors.*)
fun finalize const defs =
  (finalize'' defs const)
    handle DEFS msg => sys_error msg;
wenzelm@16877
   700
wenzelm@16877
   701
(*Merge two definition graphs (via merge'', which replays the history of the
  cheaper graph onto the other one).  The replay goes through define' and
  finalize', so it can fail in the same ways as define/finalize:
    - a replayed Define can hit CLASH or FINAL;
    - a replayed Finalize can fail with DEFS;
    - CIRCULAR and INFINITE_CHAIN can also arise.
  All of these are converted to errors, consistently with define: DEFS as a
  system error, the rest as user errors pretty-printed with pp.  Previously
  only CIRCULAR and INFINITE_CHAIN were converted, so the others escaped as
  raw exceptions.*)
fun merge pp defs1 defs2 =
  merge'' defs1 defs2
    handle DEFS msg => sys_error msg
      | CIRCULAR namess => error (defs_circular pp namess)
      | INFINITE_CHAIN namess => error (defs_infinite_chain pp namess)
      | CLASH (_, def1, def2) => error (defs_clash def1 def2)
      | FINAL const => error (defs_final pp const);
wenzelm@16877
   705
obua@16108
   706
end;
wenzelm@16877
   707
obua@16308
   708
(*
obua@16108
   709
obua@16308
   710
fun tvar name = TVar ((name, 0), [])
obua@16108
   711
obua@16108
   712
val bool = Type ("bool", [])
obua@16108
   713
val int = Type ("int", [])
obua@16308
   714
val lam = Type("lam", [])
obua@16108
   715
val alpha = tvar "'a"
obua@16108
   716
val beta = tvar "'b"
obua@16108
   717
val gamma = tvar "'c"
obua@16108
   718
fun pair a b = Type ("pair", [a,b])
obua@16308
   719
fun prm a = Type ("prm", [a])
obua@16308
   720
val name = Type ("name", [])
obua@16108
   721
obua@16108
   722
val _ = print "make empty"
wenzelm@16877
   723
val g = Defs.empty
obua@16108
   724
obua@16308
   725
val _ = print "declare perm"
obua@16308
   726
val g = Defs.declare g ("perm", prm alpha --> beta --> beta)
obua@16308
   727
obua@16308
   728
val _ = print "declare permF"
obua@16308
   729
val g = Defs.declare g ("permF", prm alpha --> lam --> lam)
obua@16308
   730
obua@16308
   731
val _ = print "define perm (1)"
wenzelm@16877
   732
val g = Defs.define g ("perm", prm alpha --> (beta --> gamma) --> (beta --> gamma)) "perm_fun"
obua@16308
   733
        [("perm", prm alpha --> gamma --> gamma), ("perm", prm alpha --> beta --> beta)]
obua@16108
   734
obua@16308
   735
val _ = print "define permF (1)"
obua@16308
   736
val g = Defs.define g ("permF", prm alpha --> lam --> lam) "permF_app"
obua@16308
   737
        ([("perm", prm alpha --> lam --> lam),
obua@16308
   738
         ("perm", prm alpha --> lam --> lam),
obua@16308
   739
         ("perm", prm alpha --> lam --> lam),
obua@16308
   740
         ("perm", prm alpha --> name --> name)])
obua@16108
   741
obua@16308
   742
val _ = print "define perm (2)"
obua@16308
   743
val g = Defs.define g ("perm", prm alpha --> lam --> lam) "perm_lam"
obua@16308
   744
        [("permF", (prm alpha --> lam --> lam))]
obua@16108
   745
wenzelm@16877
   746
*)