src/Pure/defs.ML
author obua
Sun May 29 12:39:12 2005 +0200 (2005-05-29)
changeset 16108 cf468b93a02e
child 16113 692fe6595755
permissions -rw-r--r--
Implement cycle-free overloading, so that definitions cannot harm consistency any more (except of course via interaction with axioms).
     1 (*  Title:      Pure/General/defs.ML
     2     ID:         $Id$
     3     Author:     Steven Obua, TU Muenchen
     4 
     5     Checks if definitions preserve consistency of logic by enforcing that there are no cyclic definitions.
     6     The algorithm is described in 
     7     "Cycle-free Overloading in Isabelle", Steven Obua, technical report, to be written :-)
     8 *)
     9 
    (* Interface of the dependency graph that is used to reject cyclic constant definitions. *)
    signature DEFS =
    sig
        (* the dependency graph; kept abstract by the opaque ascription of structure Defs *)
        type graph

        (* general failure, carrying an error message *)
        exception DEFS of string
        (* cyclic dependency; the elements are (type, constant name, definition name) triples *)
        exception CIRCULAR of (typ * string * string) list
        (* raised by define with (constant name, name of the new definition,
           name of the already existing definition) *)
        exception CLASH of string * string * string

        (* the graph without any declaration or definition *)
        val empty : graph

        (* declare graph name ty
           may raise DEFS *)
        val declare : graph -> string -> typ -> graph

        (* define graph name ty axname body
           may raise DEFS, CIRCULAR, CLASH *)
        val define : graph -> string -> typ -> string -> (string * typ) list -> graph

        (* merge g1 g2; the first argument should be the smaller graph
           may raise CIRCULAR, CLASH *)
        val merge : graph -> graph -> graph
    end
    26 
    27 structure Defs :> DEFS = struct
    28 
    type tyenv = Type.tyenv

    (* An edge is a tuple (maxidx, source type, target type, history).
       maxidx is the bound on type variable indices that unify_r is started from / returns;
       history lists the (type, constant name, definition name) steps the edge was composed of. *)
    type edgelabel = int * typ * typ * (typ * string * string) list

    (* nodes are looked up in the graph by the name of their constant *)
    type noderef = string

    datatype node =
        Node of
            string                 (* name of constant *)
          * typ                    (* most general type of constant *)
          * defnode Symtab.table   (* one defnode per definition of the constant for a
                                      particular type, indexed by axiom name *)
          * backref Symtab.table   (* all back references to this node, indexed by node name *)

    and defnode =
        Defnode of
            typ                    (* type of the constant in this particular definition *)
          * (noderef * (string option * edgelabel list) list) Symtab.table
                                   (* the edges, grouped by target node; within a group the edges
                                      are labelled with NONE or with SOME definition name *)

    and backref =
        Backref of
            noderef                (* the node that has defnodes which reference a certain node A *)
          * unit Symtab.table      (* the names of the defnodes that DIRECTLY reference A *)
    47 
    (* Accessors for graphs and nodes.
       getnode and get_defnode' use `the`, so they fail when the node is not in the graph;
       get_defnode and get_defnode' return NONE when the node has no definition of that name. *)

    (* the node stored in the graph under noderef *)
    fun getnode graph noderef = the (Symtab.lookup (graph, noderef))

    (* name of the constant of a node (previously defined twice with identical bodies) *)
    fun get_nodename (Node (n, _, _, _)) = n

    (* table of all definitions of a node, indexed by axiom name *)
    fun get_nodedefs (Node (_, _, defs, _)) = defs

    (* the definition called defname of the given node, if any *)
    fun get_defnode node defname = Symtab.lookup (get_nodedefs node, defname)

    (* like get_defnode, but the node is given by its reference into the graph *)
    fun get_defnode' graph noderef defname = get_defnode (getnode graph noderef) defname
    54 
    55 
    56 (*fun t2list t = rev (Symtab.foldl (fn (l, d) => d::l) ([], t))
    57 fun tmap f t = map (fn (a,b) => (a, f b)) t
    58 fun defnode2data (Defnode (typ, table)) = ("Defnode", typ, t2list table)
    59 fun backref2data (Backref (noderef, table)) = ("Backref", noderef, map fst (t2list table))
    60 fun node2data (Node (s, t, defs, backs)) = ("Node", ("nodename", s), ("nodetyp", t), 
    61 					    ("defs", tmap defnode2data (t2list defs)), ("backs", tmap backref2data (t2list backs)))
    62 fun graph2data g = ("Graph", tmap node2data (t2list g))
    63 *)
    64 
    (* The operations that were applied to a graph; the arguments are those of declare / define. *)
    datatype graphaction =
        Declare of string * typ
      | Define of string * typ * string * (string * typ) list

    (* a graph is the list of actions performed so far together with the table of its nodes *)
    type graph = graphaction list * node Symtab.table

    (* no actions, no nodes *)
    val empty = ([], Symtab.empty)

    exception DEFS of string
    exception CIRCULAR of (typ * string * string) list
    exception CLASH of string * string * string
    74 
    (* Aborts with exception DEFS carrying the given message. *)
    fun def_err msg = raise DEFS msg
    76 
    (* Adds a node without definitions and without back references for the constant name.
       The node stores the type after Type.strip_sorts and Type.varifyT; the recorded Declare
       action keeps the type as given. Raises DEFS if the graph already has a node called name. *)
    fun declare (actions, g) name ty =
        let
            val general_ty = Type.varifyT (Type.strip_sorts ty)
            val fresh_node = Node (name, general_ty, Symtab.empty, Symtab.empty)
            val g' =
                Symtab.update_new ((name, fresh_node), g)
                handle Symtab.DUP _ => def_err "declare: constant is already defined"
        in
            (Declare (name, ty) :: actions, g')
        end
    81 
    (* Returns ty2 with its type variable indices incremented by (maxidx_of_typ ty1) + 1. *)
    fun rename ty1 ty2 =
        let
            val offset = maxidx_of_typ ty1 + 1
        in
            incr_tvar offset ty2
        end
    83 
    (* Increments the type variable indices of t by inc and also returns the substitution
       that maps every type variable (n, i) of t to the sort-free variable (n, i + inc).
       For inc <= 0 the type is returned unchanged together with the empty substitution. *)
    fun subst_incr_tvar inc t =
        if inc <= 0 then
            (t, Vartab.empty)
        else
            let
                fun add_binding (((a, idx), _), env) =
                    Vartab.update (((a, idx), ([], TVar ((a, idx + inc), []))), env)
                val tvars = typ_tvars t
                val shifted = incr_tvar inc t
            in
                (shifted, List.foldl add_binding Vartab.empty tvars)
            end
    96 
    (* Renames ty2 by incrementing its type variable indices by max1 + 1.
       Returns the renamed type ty2', the substitution that takes ty2 to ty2', and the
       number max1 + (maxidx_of_typ ty2) + 1. *)
    fun subst_rename max1 ty2 =
        let
            val (renamed, env) = subst_incr_tvar (max1 + 1) ty2
            val bound = max1 + maxidx_of_typ ty2 + 1
        in
            (renamed, env, bound)
        end
   106 
   (* Applies the type substitution env to ty by means of Envir.norm_type. *)
   fun subst env ty = Envir.norm_type env ty
   108 
   (* Applies the substitution s to the type component of every history entry;
      the constant name and definition name of an entry are left untouched. *)
   fun subst_history s history =
       let
           fun inst (ty, cn, dn) = (subst s ty, cn, dn)
       in
           map inst history
       end
   110 
   (* Type.typ_instance with respect to the empty type signature. *)
   fun is_instance instance_ty general_ty =
       let
           val candidate = (instance_ty, general_ty)
       in
           Type.typ_instance Type.empty_tsig candidate
       end
   113 
   (* Like is_instance, but general_ty is first renamed relative to instance_ty (see rename). *)
   fun is_instance_r instance_ty general_ty =
       let
           val general_ty' = rename instance_ty general_ty
       in
           is_instance instance_ty general_ty'
       end
   116 
   (* Unifies ty1 and ty2 with respect to the empty type signature, starting from the empty
      type environment and index 0. Returns SOME of the resulting type environment, or NONE
      if Type.unify raises Type.TUNIFY. *)
   fun unify ty1 ty2 =
       let
           val (env, _) = Type.unify Type.empty_tsig (Vartab.empty, 0) (ty1, ty2)
       in
           SOME env
       end
       handle Type.TUNIFY => NONE
   120 
   (*
      Unifies ty1 and ty2 after renaming them apart. With base = Int.max (max, 0), the indices
      of ty1 are incremented by base + 1 and those of ty2 by 2 * base + 2. All indices in ty1
      and ty2 are supposed to be less than or equal to max.
      Returns SOME (max', s1, s2) with max' = 3 * base + 2, where s1 and s2 are the renaming
      substitutions of ty1 and ty2, respectively, each merged with the unifier;
      the intention is that s1(ty1) = s2(ty2). Returns NONE if the renamed types do not unify.
   *)
   fun unify_r max ty1 ty2 =
       let
           val base = Int.max (max, 0)
           val (ty1', s1) = subst_incr_tvar (base + 1) ty1
           val (ty2', s2) = subst_incr_tvar (2 * base + 2) ty2
           val bound = 3 * base + 2
           fun combine renaming unifier = Vartab.merge (fn _ => false) (renaming, unifier)
       in
           case unify ty1' ty2' of
               SOME s => SOME (bound, combine s1 s, combine s2 s)
             | NONE => NONE
       end
   142 
   (* Tests whether ty1 unifies with ty2 after ty2 has been renamed relative to ty1. *)
   fun can_be_unified_r ty1 ty2 =
       (case unify ty1 (rename ty1 ty2) of
            SOME _ => true
          | NONE => false)
   151 
   (* Tests whether ty1 and ty2 unify as they are, i.e. without any renaming. *)
   fun can_be_unified ty1 ty2 =
       (case unify ty1 ty2 of
            SOME _ => true
          | NONE => false)
   156 
   (* Normalizes a type: every TFree a and every TVar (a, 0) becomes the sort-free TVar (a, 0).
      A TVar with an index other than 0 is rejected with DEFS "type is not clean". *)
   fun checkT ty =
       case ty of
           Type (a, Ts) => Type (a, map checkT Ts)
         | TFree (a, _) => TVar ((a, 0), [])
         | TVar ((a, 0), _) => TVar ((a, 0), [])
         | TVar _ => def_err "type is not clean"
   161 
   (* Folds over the entries of tab; yields true iff P holds for all of them.
      P is no longer applied once it has returned false for some entry. *)
   fun forall_table P tab =
       Symtab.foldl (fn (ok, entry) => ok andalso P entry) (true, tab)
   163 
   (* Order on edge labels: NONE comes before every SOME label,
      two SOME labels are compared by string_ord. *)
   fun label_ord l1 l2 =
       case (l1, l2) of
           (NONE, NONE) => EQUAL
         | (NONE, SOME _) => LESS
         | (SOME _, NONE) => GREATER
         | (SOME a, SOME b) => string_ord (a, b)
   168 
   (* Partial comparison of two edges by the types u --> v built from their source and target.
      SOME LESS:    only the first type is an instance of the second
      SOME GREATER: only the second type is an instance of the first
      both ways:    SOME (int_ord (length history2, length history1))
      NONE:         neither type is an instance of the other *)
   fun compare_edges (_, u1, v1, history1) (_, u2, v2, history2) =
       let
           val t1 = u1 --> v1
           val t2 = u2 --> v2
       in
           case (is_instance_r t1 t2, is_instance_r t2 t1) of
               (true, true) => SOME (int_ord (length history2, length history1))
             | (true, false) => SOME LESS
             | (false, true) => SOME GREATER
             | (false, false) => NONE
       end
   184 
   (* Inserts the edge x into the edge list given as second component.
      According to compare_edges: if x is LESS than or EQUAL to some element, the list is left
      as it is and x is dropped; every element that x is GREATER than is removed; incomparable
      elements are kept. If x has not been dropped at the end of the list it is added there,
      so inserting into the empty list yields [x] (it used to yield [], which lost the edge and
      made pack_edges return [] for every input). *)
   fun merge_edges_1 (x, []) = [x]
     | merge_edges_1 (x, y :: ys) =
       (case compare_edges x y of
            SOME LESS => y :: ys
          | SOME EQUAL => y :: ys
          | SOME GREATER => merge_edges_1 (x, ys)
          | NONE => y :: merge_edges_1 (x, ys))

   (* Inserts the edges of ys, one after the other, into xs. *)
   fun merge_edges xs ys = foldl merge_edges_1 xs ys

   (* Inserts all edges of xs into the empty edge list. *)
   fun pack_edges xs = merge_edges [] xs

   (* Merges two lists of labelled edge groups. Both lists are walked in parallel by label_ord,
      which presupposes that they are sorted by label; groups with equal labels are combined
      by merge_edges. *)
   fun merge_labelled_edges [] es = es
     | merge_labelled_edges es [] = es
     | merge_labelled_edges ((l1, e1) :: es1) ((l2, e2) :: es2) =
       (case label_ord l1 l2 of
            LESS => (l1, e1) :: merge_labelled_edges es1 ((l2, e2) :: es2)
          | GREATER => (l2, e2) :: merge_labelled_edges ((l1, e1) :: es1) es2
          | EQUAL => (l1, merge_edges e1 e2) :: merge_labelled_edges es1 es2)

   (* Folds over all edges of a defnode. For every edge e the function f is called as
      f ty n s e acc, where ty is the type of the defnode, n the target node of the edge
      and s the label (NONE or SOME definition name) of the group that e belongs to. *)
   fun defnode_edges_foldl f a defnode =
       let
           val (Defnode (ty, def_edges)) = defnode
           fun g (b, (_, (n, labelled_edges))) =
               foldl (fn ((s, edges), b') =>
                         foldl (fn (e, b'') => f ty n s e b'') b' edges)
                     b
                     labelled_edges
       in
           Symtab.foldl g (a, def_edges)
       end

   (* define (actions, graph) name ty axname body

      Adds the definition axname of the constant name at type ty to the graph. body lists the
      constants, together with their types, that the definition depends on. ty and the types
      in body are normalized by checkT first.

      Raises DEFS if name or a constant of body is not declared, if ty does not match the
      declared type, if axname is already used for a definition of name, or if checkT rejects
      a type. Raises CLASH if ty can be unified with the type of an existing definition of name.
      Raises CIRCULAR if the new definition depends on itself. *)
   fun define (actions, graph) name ty axname body =
       let
           val ty = checkT ty
           val body = map (fn (n, t) => (n, checkT t)) body
           val mainref = name
           val mainnode =
               (case Symtab.lookup (graph, mainref) of
                    NONE => def_err ("constant " ^ (quote mainref) ^ " is not declared")
                  | SOME n => n)
           val (Node (n, gty, defs, backs)) = mainnode
           val _ =
               if is_instance_r ty gty then ()
               else def_err "type of constant does not match declared type"

           (* the new definition must not overlap with any existing definition of the constant *)
           fun check_def (s, Defnode (ty', _)) =
               if can_be_unified_r ty ty' then
                   raise (CLASH (mainref, axname, s))
               else if s = axname then
                   def_err "name of axiom is already used for another definition of this constant"
               else true
           val _ = forall_table check_def defs
           (* now we know that the only thing that can prevent acceptance of the definition
              is a cyclic dependency *)

           (* body contains the constants that this constant definition depends on. For each
              element of body, the function make_edges_to calculates a group of edges that
              connect this constant with the constant that is denoted by the element of the
              body. The result is NONE if the type in the body does not unify with the most
              general type of the referenced constant. *)
           fun make_edges_to (bodyn, bodyty) =
               let
                   val bnode =
                       (case Symtab.lookup (graph, bodyn) of
                            NONE => def_err "body of constant definition references undeclared constant"
                          | SOME x => x)
                   val (Node (_, general_btyp, bdefs, _)) = bnode
               in
                   case unify_r 0 bodyty general_btyp of
                       NONE => NONE
                     | SOME (maxidx, sigma1, sigma2) =>
                       SOME (
                       let
                           (* For each definition of the constant in the body, check if the
                              definition unifies with the type of the constant in the body.
                              As soon as bodyty is an instance of the type of a definition
                              (swallowed), the remaining definitions are skipped. *)
                           fun make_edges ((swallowed, l), (def_name, Defnode (def_ty, _))) =
                               if swallowed then
                                   (swallowed, l)
                               else
                                   (case unify_r 0 bodyty def_ty of
                                        NONE => (swallowed, l)
                                      | SOME (maxidx, sigma1, sigma2) =>
                                        (is_instance bodyty def_ty,
                                         merge_labelled_edges l
                                           [(SOME def_name,
                                             [(maxidx, subst sigma1 ty, subst sigma2 def_ty, [])])]))
                           val (swallowed, edges) = Symtab.foldl make_edges ((false, []), bdefs)
                       in
                           if swallowed then
                               (bodyn, edges)
                           else
                               (* not covered by a single definition: keep an edge, labelled
                                  NONE, to the constant at its most general type *)
                               (bodyn,
                                [(NONE, [(maxidx, subst sigma1 ty, subst sigma2 general_btyp, [])])]
                                @ edges)
                       end)
               end

           (* adds the edges for one element of body to the table of edges, which is indexed
              by the name of the referenced constant *)
           fun update_edges (b as (bodyn, bodyty), edges) =
               (case make_edges_to b of
                    NONE => edges
                  | SOME m =>
                    (case Symtab.lookup (edges, bodyn) of
                         NONE => Symtab.update ((bodyn, m), edges)
                       | SOME (_, es') =>
                         let
                             val (_, es) = m
                             val es = merge_labelled_edges es es'
                         in
                             Symtab.update ((bodyn, (bodyn, es)), edges)
                         end))

           val edges = foldl update_edges Symtab.empty body

           (* adds a single labelled edge that points to nodename to the table of edges *)
           fun insert_edge edges (nodename, (defname_opt, edge)) =
               let
                   val newlink = [(defname_opt, [edge])]
               in
                   case Symtab.lookup (edges, nodename) of
                       NONE => Symtab.update ((nodename, (nodename, newlink)), edges)
                     | SOME (_, links) =>
                       let
                           val links' = merge_labelled_edges links newlink
                       in
                           Symtab.update ((nodename, (nodename, links')), edges)
                       end
               end

           (* We constructed all direct edges that this defnode has.
              Now we have to construct the transitive hull by going a single step further. *)

           val thisDefnode = Defnode (ty, edges)

           (* For a direct edge that leads to the definition defname of noderef, every edge of
              that definition is composed with the direct edge, provided that the target type
              of the first unifies with the source type of the second. The history of the
              composed edge records the intermediate (type, noderef, defname) step.
              Edges labelled NONE are not followed. *)
           fun make_trans_edges _ noderef defname_opt (max1, alpha1, beta1, history1) edges =
               (case defname_opt of
                    NONE => edges
                  | SOME defname =>
                    let
                        val defnode = the (get_defnode' graph noderef defname)
                        fun make_trans_edge _ noderef2 defname_opt2 (max2, alpha2, beta2, history2) edges =
                            (case unify_r (Int.max (max1, max2)) beta1 alpha2 of
                                 NONE => edges
                               | SOME (max, sleft, sright) =>
                                 insert_edge edges
                                   (noderef2,
                                    (defname_opt2,
                                     (max, subst sleft alpha1, subst sright beta2,
                                      (subst_history sleft history1) @
                                      ((subst sleft beta1, noderef, defname) ::
                                       (subst_history sright history2))))))
                    in
                        defnode_edges_foldl make_trans_edge edges defnode
                    end)

           val edges = defnode_edges_foldl make_trans_edges edges thisDefnode

           val thisDefnode = Defnode (ty, edges)

           (* We also have to add the backreferences that this new defnode induces. *)

           (* relies on the NONE group being the first one of a sorted group list (see label_ord) *)
           fun hasNONElink ((NONE, _) :: _) = true
             | hasNONElink _ = false

           (* records in the node noderef that the definition pointingdefname of the node
              pointingnoderef references it; the graph is returned unchanged if this is
              already recorded (Symtab.DUP) *)
           fun install_backref graph noderef pointingnoderef pointingdefname =
               let
                   val (Node (pname, _, _, _)) = getnode graph pointingnoderef
                   val (Node (name, ty, defs, backs)) = getnode graph noderef
               in
                   case Symtab.lookup (backs, pname) of
                       NONE =>
                       let
                           val defnames = Symtab.update ((pointingdefname, ()), Symtab.empty)
                           val backs = Symtab.update ((pname, Backref (pointingnoderef, defnames)), backs)
                       in
                           Symtab.update ((name, Node (name, ty, defs, backs)), graph)
                       end
                     | SOME (Backref (pointingnoderef, defnames)) =>
                       (let
                            val defnames = Symtab.update_new ((pointingdefname, ()), defnames)
                            val backs = Symtab.update ((pname, Backref (pointingnoderef, defnames)), backs)
                        in
                            Symtab.update ((name, Node (name, ty, defs, backs)), graph)
                        end
                        handle Symtab.DUP _ => graph)
               end

           (* only edge groups that still contain a NONE link give rise to a back reference *)
           fun install_backrefs (graph, (_, (noderef, labelled_edges))) =
               if hasNONElink labelled_edges then
                   install_backref graph noderef mainref axname
               else
                   graph

           val graph = Symtab.foldl install_backrefs (graph, edges)

           (* re-read the back references of this node: install_backrefs may have changed them *)
           val (Node (_, _, _, backs)) = getnode graph mainref
           val graph =
               Symtab.update
                 ((mainref, Node (n, gty, Symtab.update_new ((axname, thisDefnode), defs), backs)),
                  graph)

           (* Now we have to check all backreferences to this node and inform them about the
              new defnode. In this section we also check for circularity.
              The result is the new table of back references of this node together with, per
              referencing node, a list of (defname, none_edges, this_edges): the edges of
              defname that keep pointing to this constant in general and the new edges that
              point to the new definition. *)
           fun update_backrefs ((backs, newedges), (nodename, Backref (noderef, defnames))) =
               let
                   val node = getnode graph noderef
                   fun update_defs ((defnames, newedges), (defname, _)) =
                       let
                           val (Defnode (_, defnode_edges)) = the (get_defnode node defname)
                           val (_, labelled_edges) = the (Symtab.lookup (defnode_edges, n))

                           (* the type of thisDefnode is ty *)
                           fun update (e as (max, alpha, beta, history), (none_edges, this_edges)) =
                               (case unify_r max beta ty of
                                    NONE => (e :: none_edges, this_edges)
                                  | SOME (max', s_beta, s_ty) =>
                                    let
                                        val alpha' = subst s_beta alpha
                                        val ty' = subst s_ty ty
                                        (* an edge from the new definition back to itself
                                           whose source and target types unify is a cycle *)
                                        val _ =
                                            if noderef = mainref andalso defname = axname then
                                                (case unify alpha' ty' of
                                                     NONE => ()
                                                   | SOME s =>
                                                     raise (CIRCULAR (
                                                            (subst s alpha', mainref, axname) ::
                                                            (subst_history s (subst_history s_beta history)) @
                                                            [(subst s ty', mainref, axname)])))
                                            else ()
                                        val edge = (max', alpha', ty', subst_history s_beta history)
                                    in
                                        if is_instance_r beta ty then
                                            (* completely covered by the new definition *)
                                            (none_edges, edge :: this_edges)
                                        else
                                            (e :: none_edges, edge :: this_edges)
                                    end)
                       in
                           case labelled_edges of
                               ((NONE, edges) :: _) =>
                               let
                                   val (none_edges, this_edges) = foldl update ([], []) edges
                                   (* defname stays a back reference only if NONE edges remain *)
                                   val defnames =
                                       if none_edges = [] then defnames
                                       else Symtab.update_new ((defname, ()), defnames)
                               in
                                   (defnames, (defname, none_edges, this_edges) :: newedges)
                               end
                             | _ => def_err "update_defs, internal error, corrupt backrefs"
                       end

                   val (defnames, newedges') = Symtab.foldl update_defs ((Symtab.empty, []), defnames)

                   (* The new edges have to be recorded whether or not the node remains a back
                      reference. They used to be dropped when defnames was not empty, so that
                      the edges to the new definition were never installed for such nodes. *)
                   val newedges = (noderef, newedges') :: newedges
               in
                   if Symtab.is_empty defnames then
                       (backs, newedges)
                   else
                       let
                           val backs = Symtab.update_new ((nodename, Backref (noderef, defnames)), backs)
                       in
                           (backs, newedges)
                       end
               end

           val (backs, newedges) = Symtab.foldl update_backrefs ((Symtab.empty, []), backs)

           (* If a Circular exception is thrown then we never reach this point. *)
           (* Ok, the definition is consistent, let's update this node. *)
           val graph =
               Symtab.update
                 ((mainref, Node (n, gty, Symtab.update ((axname, thisDefnode), defs), backs)),
                  graph)

           (* Furthermore, update all the other nodes that backreference this node:
              in the definition defname of noderef the group of NONE edges that point to this
              constant is replaced by none_edges (if any), and this_edges are added under the
              label SOME axname. *)
           fun final_update_backrefs graph noderef defname none_edges this_edges =
               let
                   val node = getnode graph noderef
                   val (Node (nodename, nodety, defs, backs)) = node
                   val (Defnode (defnode_ty, defnode_edges)) = the (get_defnode node defname)
                   val (_, defnode_links) = the (Symtab.lookup (defnode_edges, n))

                   fun update edges none_edges this_edges =
                       let
                           val u = merge_labelled_edges edges [(SOME axname, pack_edges this_edges)]
                       in
                           if none_edges = [] then
                               u
                           else
                               (NONE, pack_edges none_edges) :: u
                       end

                   val defnode_links' =
                       case defnode_links of
                           ((NONE, _) :: edges) => update edges none_edges this_edges
                         | edges => update edges none_edges this_edges
                   val defnode_edges' = Symtab.update ((n, (mainref, defnode_links')), defnode_edges)
                   val defs' = Symtab.update ((defname, Defnode (defnode_ty, defnode_edges')), defs)
               in
                   Symtab.update ((nodename, Node (nodename, nodety, defs', backs)), graph)
               end

           val graph =
               foldl (fn ((noderef, newedges), graph) =>
                         foldl (fn ((defname, none_edges, this_edges), graph) =>
                                   final_update_backrefs graph noderef defname none_edges this_edges)
                               graph newedges)
                     graph newedges

       in
           ((Define (name, ty, axname, body)) :: actions, graph)
       end
   471 
   472     
   (* Replays a single recorded action on the graph g.
      Declare: the constant is declared; if declare fails with DEFS (the constant is already
      present), g is kept as it is. Only DEFS is handled here; the former catch-all handler
      also swallowed every other exception.
      Define: the definition is added unless the constant already has a definition with the
      same axiom name; exceptions of define are passed on. *)
   fun merge' (Declare (name, ty), g) = (declare g name ty handle DEFS _ => g)
     | merge' (Define (name, ty, axname, body), g as (_, graph)) =
       (case Symtab.lookup (graph, name) of
            SOME (Node (_, _, defs, _)) =>
            (case Symtab.lookup (defs, axname) of
                 SOME _ => g
               | NONE => define g name ty axname body)
          | NONE => define g name ty axname body)
   481 	
   (* Replays the actions of the first graph on the graph g. The action list has the most
      recent action at its head (declare and define cons onto it), so it is reversed in order
      to replay the actions oldest first. The node table of the first graph is not used. *)
   fun merge (actions, _) g =
       let
           val oldest_first = rev actions
       in
           List.foldl merge' g oldest_first
       end
   483 
   484 end;
   485 		
   486 
   487 
   488 (*fun tvar name = TVar ((name, 0), [])
   489 
   490 val bool = Type ("bool", [])
   491 val int = Type ("int", [])
   492 val alpha = tvar "'a"
   493 val beta = tvar "'b"
   494 val gamma = tvar "'c"
   495 fun pair a b = Type ("pair", [a,b])
   496 
   497 val _ = print "make empty"
   498 val g = Defs.empty 
   499 
   500 val _ = print "declare"
   501 val g = Defs.declare g "M" (alpha --> bool)
   502 val g = Defs.declare g "N" (beta --> bool)
   503 
   504 val _ = print "define"
   505 val g = Defs.define g "N" (alpha --> bool) "defN" [("M", alpha --> bool)]
   506 val g = Defs.define g "M" (alpha --> bool) "defM" [("N", int --> alpha)]
   507 
   508 val g = Defs.declare g "0" alpha
   509 val g = Defs.define g "0" (pair alpha beta) "zp" [("0", alpha), ("0", beta)]*)
   510 
   511