src/Provers/trancl.ML
author wenzelm
Tue Sep 29 16:24:36 2009 +0200 (2009-09-29)
changeset 32740 9dd0a2f83429
parent 32285 ab9b66c2bbca
child 32768 e4a3f9c3d4f5
permissions -rw-r--r--
explicit indication of Unsynchronized.ref;
     1 (*
     2     Title:      Transitivity reasoner for transitive closures of relations
     3     Author:     Oliver Kutter, TU Muenchen
     4 *)
     5 
     6 (*
     7 
     8 This package provides tactics trancl_tac and rtrancl_tac that prove
     9 goals of the form
    10 
    11    (x,y) : r^+     and     (x,y) : r^* (rtrancl_tac only)
    12 
    13 from premises of the form
    14 
    15    (x,y) : r,     (x,y) : r^+     and     (x,y) : r^* (rtrancl_tac only)
    16 
    17 by reflexivity and transitivity.  The relation r is determined by inspecting
    18 the conclusion.
    19 
    20 The package is implemented as an ML functor and thus not limited to
    21 particular constructs for transitive and reflexive-transitive
    22 closures, neither need relations be represented as sets of pairs.  In
    23 order to instantiate the package for transitive closure only, supply
    24 dummy theorems to the additional rules for reflexive-transitive
    25 closures, and don't use rtrancl_tac!
    26 
    27 *)
    28 
    29 signature TRANCL_ARITH =
    30 sig
    31 
    32   (* theorems for transitive closure *)
    33 
    34   val r_into_trancl : thm
    35       (* (a,b) : r ==> (a,b) : r^+ *)
    36   val trancl_trans : thm
    37       (* [| (a,b) : r^+ ; (b,c) : r^+ |] ==> (a,c) : r^+ *)
    38 
    39   (* additional theorems for reflexive-transitive closure *)
    40 
    41   val rtrancl_refl : thm
    42       (* (a,a): r^* *)
    43   val r_into_rtrancl : thm
    44       (* (a,b) : r ==> (a,b) : r^* *)
    45   val trancl_into_rtrancl : thm
    46       (* (a,b) : r^+ ==> (a,b) : r^* *)
    47   val rtrancl_trancl_trancl : thm
    48       (* [| (a,b) : r^* ; (b,c) : r^+ |] ==> (a,c) : r^+ *)
    49   val trancl_rtrancl_trancl : thm
    50       (* [| (a,b) : r^+ ; (b,c) : r^* |] ==> (a,c) : r^+ *)
    51   val rtrancl_trans : thm
    52       (* [| (a,b) : r^* ; (b,c) : r^* |] ==> (a,c) : r^* *)
    53 
    54   (* decomp: decompose a premise or conclusion
    55 
    56      Returns one of the following:
    57 
    58      NONE if not an instance of a relation,
    59      SOME (x, y, r, s) if instance of a relation, where
    60        x: left hand side argument, y: right hand side argument,
    61        r: the relation,
    62        s: the kind of closure, one of
    63             "r":   the relation itself,
    64             "r^+": transitive closure of the relation,
    65             "r^*": reflexive-transitive closure of the relation
    66   *)
    67 
    68   val decomp: term ->  (term * term * term * string) option
    69 
    70 end;
    71 
    72 signature TRANCL_TAC =
    73 sig
    74   val trancl_tac: Proof.context -> int -> tactic
    75   val rtrancl_tac: Proof.context -> int -> tactic
    76 end;
    77 
    78 functor Trancl_Tac(Cls: TRANCL_ARITH): TRANCL_TAC =
    79 struct
    80 
    81 
    82 datatype proof
    83   = Asm of int
    84   | Thm of proof list * thm;
    85 
    86 exception Cannot; (* internal exception: raised if no proof can be found *)
    87 
    88 fun decomp t = Option.map (fn (x, y, rel, r) =>
    89   (Envir.beta_eta_contract x, Envir.beta_eta_contract y,
    90    Envir.beta_eta_contract rel, r)) (Cls.decomp t);
    91 
    92 fun prove thy r asms =
    93   let
    94     fun inst thm =
    95       let val SOME (_, _, r', _) = decomp (concl_of thm)
    96       in Drule.cterm_instantiate [(cterm_of thy r', cterm_of thy r)] thm end;
    97     fun pr (Asm i) = List.nth (asms, i)
    98       | pr (Thm (prfs, thm)) = map pr prfs MRS inst thm
    99   in pr end;
   100 
   101 
   102 (* Internal datatype for inequalities *)
   103 datatype rel
   104    = Trans  of term * term * proof  (* R^+ *)
   105    | RTrans of term * term * proof; (* R^* *)
   106 
   107  (* Misc functions for datatype rel *)
   108 fun lower (Trans (x, _, _)) = x
   109   | lower (RTrans (x,_,_)) = x;
   110 
   111 fun upper (Trans (_, y, _)) = y
   112   | upper (RTrans (_,y,_)) = y;
   113 
   114 fun getprf   (Trans   (_, _, p)) = p
   115 |   getprf   (RTrans (_,_, p)) = p;
   116 
   117 (* ************************************************************************ *)
   118 (*                                                                          *)
   119 (*  mkasm_trancl Rel (t,n): term -> (term , int) -> rel list                *)
   120 (*                                                                          *)
   121 (*  Analyse assumption t with index n with respect to relation Rel:         *)
   122 (*  If t is of the form "(x, y) : Rel" (or Rel^+), translate to             *)
   123 (*  an object (singleton list) of internal datatype rel.                    *)
   124 (*  Otherwise return empty list.                                            *)
   125 (*                                                                          *)
   126 (* ************************************************************************ *)
   127 
   128 fun mkasm_trancl  Rel  (t, n) =
   129   case decomp t of
   130     SOME (x, y, rel,r) => if rel aconv Rel then
   131 
   132     (case r of
   133       "r"   => [Trans (x,y, Thm([Asm n], Cls.r_into_trancl))]
   134     | "r+"  => [Trans (x,y, Asm n)]
   135     | "r*"  => []
   136     | _     => error ("trancl_tac: unknown relation symbol"))
   137     else []
   138   | NONE => [];
   139 
   140 (* ************************************************************************ *)
   141 (*                                                                          *)
   142 (*  mkasm_rtrancl Rel (t,n): term -> (term , int) -> rel list               *)
   143 (*                                                                          *)
   144 (*  Analyse assumption t with index n with respect to relation Rel:         *)
   145 (*  If t is of the form "(x, y) : Rel" (or Rel^+ or Rel^* ), translate to   *)
   146 (*  an object (singleton list) of internal datatype rel.                    *)
   147 (*  Otherwise return empty list.                                            *)
   148 (*                                                                          *)
   149 (* ************************************************************************ *)
   150 
   151 fun mkasm_rtrancl Rel (t, n) =
   152   case decomp t of
   153    SOME (x, y, rel, r) => if rel aconv Rel then
   154     (case r of
   155       "r"   => [ Trans (x,y, Thm([Asm n], Cls.r_into_trancl))]
   156     | "r+"  => [ Trans (x,y, Asm n)]
   157     | "r*"  => [ RTrans(x,y, Asm n)]
   158     | _     => error ("rtrancl_tac: unknown relation symbol" ))
   159    else []
   160   | NONE => [];
   161 
   162 (* ************************************************************************ *)
   163 (*                                                                          *)
   164 (*  mkconcl_trancl t: term -> (term, rel, proof)                            *)
   165 (*  mkconcl_rtrancl t: term -> (term, rel, proof)                           *)
   166 (*                                                                          *)
   167 (*  Analyse conclusion t:                                                   *)
   168 (*    - must be of form "(x, y) : r^+ (or r^* for rtrancl)                  *)
   169 (*    - returns r                                                           *)
   170 (*    - conclusion in internal form                                         *)
   171 (*    - proof object                                                        *)
   172 (*                                                                          *)
   173 (* ************************************************************************ *)
   174 
   175 fun mkconcl_trancl  t =
   176   case decomp t of
   177     SOME (x, y, rel, r) => (case r of
   178       "r+"  => (rel, Trans (x,y, Asm ~1), Asm 0)
   179     | _     => raise Cannot)
   180   | NONE => raise Cannot;
   181 
   182 fun mkconcl_rtrancl  t =
   183   case decomp t of
   184     SOME (x,  y, rel,r ) => (case r of
   185       "r+"  => (rel, Trans (x,y, Asm ~1),  Asm 0)
   186     | "r*"  => (rel, RTrans (x,y, Asm ~1), Asm 0)
   187     | _     => raise Cannot)
   188   | NONE => raise Cannot;
   189 
   190 (* ************************************************************************ *)
   191 (*                                                                          *)
   192 (*  makeStep (r1, r2): rel * rel -> rel                                     *)
   193 (*                                                                          *)
   194 (*  Apply transitivity to r1 and r2, obtaining a new element of r^+ or r^*, *)
   195 (*  according the following rules:                                          *)
   196 (*                                                                          *)
   197 (* ( (a, b) : r^+ , (b,c) : r^+ ) --> (a,c) : r^+                           *)
   198 (* ( (a, b) : r^* , (b,c) : r^+ ) --> (a,c) : r^+                           *)
   199 (* ( (a, b) : r^+ , (b,c) : r^* ) --> (a,c) : r^+                           *)
   200 (* ( (a, b) : r^* , (b,c) : r^* ) --> (a,c) : r^*                           *)
   201 (*                                                                          *)
   202 (* ************************************************************************ *)
   203 
   204 fun makeStep (Trans (a,_,p), Trans(_,c,q))  = Trans (a,c, Thm ([p,q], Cls.trancl_trans))
   205 (* refl. + trans. cls. rules *)
   206 |   makeStep (RTrans (a,_,p), Trans(_,c,q))  = Trans (a,c, Thm ([p,q], Cls.rtrancl_trancl_trancl))
   207 |   makeStep (Trans (a,_,p), RTrans(_,c,q))  = Trans (a,c, Thm ([p,q], Cls.trancl_rtrancl_trancl))
   208 |   makeStep (RTrans (a,_,p), RTrans(_,c,q))  = RTrans (a,c, Thm ([p,q], Cls.rtrancl_trans));
   209 
   210 (* ******************************************************************* *)
   211 (*                                                                     *)
   212 (* transPath (Clslist, Cls): (rel  list * rel) -> rel                  *)
   213 (*                                                                     *)
   214 (* If a path represented by a list of elements of type rel is found,   *)
   215 (* this needs to be contracted to a single element of type rel.        *)
   216 (* Prior to each transitivity step it is checked whether the step is   *)
   217 (* valid.                                                              *)
   218 (*                                                                     *)
   219 (* ******************************************************************* *)
   220 
   221 fun transPath ([],acc) = acc
   222 |   transPath (x::xs,acc) = transPath (xs, makeStep(acc,x))
   223 
   224 (* ********************************************************************* *)
   225 (* Graph functions                                                       *)
   226 (* ********************************************************************* *)
   227 
   228 (* *********************************************************** *)
   229 (* Functions for constructing graphs                           *)
   230 (* *********************************************************** *)
   231 
   232 fun addEdge (v,d,[]) = [(v,d)]
   233 |   addEdge (v,d,((u,dl)::el)) = if v aconv u then ((v,d@dl)::el)
   234     else (u,dl):: (addEdge(v,d,el));
   235 
   236 (* ********************************************************************** *)
   237 (*                                                                        *)
   238 (* mkGraph constructs from a list of objects of type rel  a graph g       *)
   239 (* and a list of all edges with label r+.                                 *)
   240 (*                                                                        *)
   241 (* ********************************************************************** *)
   242 
   243 fun mkGraph [] = ([],[])
   244 |   mkGraph ys =
   245  let
   246   fun buildGraph ([],g,zs) = (g,zs)
   247   |   buildGraph (x::xs, g, zs) =
   248         case x of (Trans (_,_,_)) =>
   249                buildGraph (xs, addEdge((upper x), [],(addEdge ((lower x),[((upper x),x)],g))), x::zs)
   250         | _ => buildGraph (xs, addEdge((upper x), [],(addEdge ((lower x),[((upper x),x)],g))), zs)
   251 in buildGraph (ys, [], []) end;
   252 
   253 (* *********************************************************************** *)
   254 (*                                                                         *)
   255 (* adjacent g u : (''a * 'b list ) list -> ''a -> 'b list                  *)
   256 (*                                                                         *)
   257 (* List of successors of u in graph g                                      *)
   258 (*                                                                         *)
   259 (* *********************************************************************** *)
   260 
   261 fun adjacent eq_comp ((v,adj)::el) u =
   262     if eq_comp (u, v) then adj else adjacent eq_comp el u
   263 |   adjacent _  []  _ = []
   264 
   265 (* *********************************************************************** *)
   266 (*                                                                         *)
   267 (* dfs eq_comp g u v:                                                      *)
   268 (* ('a * 'a -> bool) -> ('a  *( 'a * rel) list) list ->                    *)
   269 (* 'a -> 'a -> (bool * ('a * rel) list)                                    *)
   270 (*                                                                         *)
   271 (* Depth first search of v from u.                                         *)
   272 (* Returns (true, path(u, v)) if successful, otherwise (false, []).        *)
   273 (*                                                                         *)
   274 (* *********************************************************************** *)
   275 
   276 fun dfs eq_comp g u v =
   277  let
   278     val pred = Unsynchronized.ref [];
   279     val visited = Unsynchronized.ref [];
   280 
   281     fun been_visited v = exists (fn w => eq_comp (w, v)) (!visited)
   282 
   283     fun dfs_visit u' =
   284     let val _ = visited := u' :: (!visited)
   285 
   286     fun update (x,l) = let val _ = pred := (x,l) ::(!pred) in () end;
   287 
   288     in if been_visited v then ()
   289     else (app (fn (v',l) => if been_visited v' then () else (
   290        update (v',l);
   291        dfs_visit v'; ()) )) (adjacent eq_comp g u')
   292      end
   293   in
   294     dfs_visit u;
   295     if (been_visited v) then (true, (!pred)) else (false , [])
   296   end;
   297 
   298 (* *********************************************************************** *)
   299 (*                                                                         *)
   300 (* transpose g:                                                            *)
   301 (* (''a * ''a list) list -> (''a * ''a list) list                          *)
   302 (*                                                                         *)
   303 (* Computes transposed graph g' from g                                     *)
   304 (* by reversing all edges u -> v to v -> u                                 *)
   305 (*                                                                         *)
   306 (* *********************************************************************** *)
   307 
   308 fun transpose eq_comp g =
   309   let
   310    (* Compute list of reversed edges for each adjacency list *)
   311    fun flip (u,(v,l)::el) = (v,(u,l)) :: flip (u,el)
   312      | flip (_,nil) = nil
   313 
   314    (* Compute adjacency list for node u from the list of edges
   315       and return a likewise reduced list of edges.  The list of edges
   316       is searches for edges starting from u, and these edges are removed. *)
   317    fun gather (u,(v,w)::el) =
   318     let
   319      val (adj,edges) = gather (u,el)
   320     in
   321      if eq_comp (u, v) then (w::adj,edges)
   322      else (adj,(v,w)::edges)
   323     end
   324    | gather (_,nil) = (nil,nil)
   325 
   326    (* For every node in the input graph, call gather to find all reachable
   327       nodes in the list of edges *)
   328    fun assemble ((u,_)::el) edges =
   329        let val (adj,edges) = gather (u,edges)
   330        in (u,adj) :: assemble el edges
   331        end
   332      | assemble nil _ = nil
   333 
   334    (* Compute, for each adjacency list, the list with reversed edges,
   335       and concatenate these lists. *)
   336    val flipped = List.foldr (op @) nil (map flip g)
   337 
   338  in assemble g flipped end
   339 
   340 (* *********************************************************************** *)
   341 (*                                                                         *)
   342 (* dfs_reachable eq_comp g u:                                              *)
   343 (* (int * int list) list -> int -> int list                                *)
   344 (*                                                                         *)
   345 (* Computes list of all nodes reachable from u in g.                       *)
   346 (*                                                                         *)
   347 (* *********************************************************************** *)
   348 
   349 fun dfs_reachable eq_comp g u =
   350  let
   351   (* List of vertices which have been visited. *)
   352   val visited  = Unsynchronized.ref nil;
   353 
   354   fun been_visited v = exists (fn w => eq_comp (w, v)) (!visited)
   355 
   356   fun dfs_visit g u  =
   357       let
   358    val _ = visited := u :: !visited
   359    val descendents =
   360        List.foldr (fn ((v,l),ds) => if been_visited v then ds
   361             else v :: dfs_visit g v @ ds)
   362         nil (adjacent eq_comp g u)
   363    in  descendents end
   364 
   365  in u :: dfs_visit g u end;
   366 
   367 (* *********************************************************************** *)
   368 (*                                                                         *)
   369 (* dfs_term_reachable g u:                                                  *)
   370 (* (term * term list) list -> term -> term list                            *)
   371 (*                                                                         *)
   372 (* Computes list of all nodes reachable from u in g.                       *)
   373 (*                                                                         *)
   374 (* *********************************************************************** *)
   375 
   376 fun dfs_term_reachable g u = dfs_reachable (op aconv) g u;
   377 
   378 (* ************************************************************************ *)
   379 (*                                                                          *)
   380 (* findPath x y g: Term.term -> Term.term ->                                *)
   381 (*                  (Term.term * (Term.term * rel list) list) ->            *)
   382 (*                  (bool, rel list)                                        *)
   383 (*                                                                          *)
   384 (*  Searches a path from vertex x to vertex y in Graph g, returns true and  *)
   385 (*  the list of edges if path is found, otherwise false and nil.            *)
   386 (*                                                                          *)
   387 (* ************************************************************************ *)
   388 
   389 fun findPath x y g =
   390   let
   391    val (found, tmp) =  dfs (op aconv) g x y ;
   392    val pred = map snd tmp;
   393 
   394    fun path x y  =
   395     let
   396          (* find predecessor u of node v and the edge u -> v *)
   397 
   398       fun lookup v [] = raise Cannot
   399       |   lookup v (e::es) = if (upper e) aconv v then e else lookup v es;
   400 
   401       (* traverse path backwards and return list of visited edges *)
   402       fun rev_path v =
   403         let val l = lookup v pred
   404             val u = lower l;
   405         in
   406           if u aconv x then [l] else (rev_path u) @ [l]
   407         end
   408 
   409     in rev_path y end;
   410 
   411    in
   412 
   413 
   414       if found then ( (found, (path x y) )) else (found,[])
   415 
   416 
   417 
   418    end;
   419 
   420 (* ************************************************************************ *)
   421 (*                                                                          *)
   422 (* findRtranclProof g tranclEdges subgoal:                                  *)
   423 (* (Term.term * (Term.term * rel list) list) -> rel -> proof list           *)
   424 (*                                                                          *)
   425 (* Searches in graph g a proof for subgoal.                                 *)
   426 (*                                                                          *)
   427 (* ************************************************************************ *)
   428 
   429 fun findRtranclProof g tranclEdges subgoal =
   430    case subgoal of (RTrans (x,y,_)) => if x aconv y then [Thm ([], Cls.rtrancl_refl)] else (
   431      let val (found, path) = findPath (lower subgoal) (upper subgoal) g
   432      in
   433        if found then (
   434           let val path' = (transPath (tl path, hd path))
   435           in
   436 
   437             case path' of (Trans (_,_,p)) => [Thm ([p], Cls.trancl_into_rtrancl )]
   438             | _ => [getprf path']
   439 
   440           end
   441        )
   442        else raise Cannot
   443      end
   444    )
   445 
   446 | (Trans (x,y,_)) => (
   447 
   448   let
   449    val Vx = dfs_term_reachable g x;
   450    val g' = transpose (op aconv) g;
   451    val Vy = dfs_term_reachable g' y;
   452 
   453    fun processTranclEdges [] = raise Cannot
   454    |   processTranclEdges (e::es) =
   455           if (upper e) mem Vx andalso (lower e) mem Vx
   456           andalso (upper e) mem Vy andalso (lower e) mem Vy
   457           then (
   458 
   459 
   460             if (lower e) aconv x then (
   461               if (upper e) aconv y then (
   462                   [(getprf e)]
   463               )
   464               else (
   465                   let
   466                     val (found,path) = findPath (upper e) y g
   467                   in
   468 
   469                    if found then (
   470                        [getprf (transPath (path, e))]
   471                       ) else processTranclEdges es
   472 
   473                   end
   474               )
   475             )
   476             else if (upper e) aconv y then (
   477                let val (xufound,xupath) = findPath x (lower e) g
   478                in
   479 
   480                   if xufound then (
   481 
   482                     let val xuRTranclEdge = transPath (tl xupath, hd xupath)
   483                             val xyTranclEdge = makeStep(xuRTranclEdge,e)
   484 
   485                                 in [getprf xyTranclEdge] end
   486 
   487                  ) else processTranclEdges es
   488 
   489                end
   490             )
   491             else (
   492 
   493                 let val (xufound,xupath) = findPath x (lower e) g
   494                     val (vyfound,vypath) = findPath (upper e) y g
   495                  in
   496                     if xufound then (
   497                          if vyfound then (
   498                             let val xuRTranclEdge = transPath (tl xupath, hd xupath)
   499                                 val vyRTranclEdge = transPath (tl vypath, hd vypath)
   500                                 val xyTranclEdge = makeStep (makeStep(xuRTranclEdge,e),vyRTranclEdge)
   501 
   502                                 in [getprf xyTranclEdge] end
   503 
   504                          ) else processTranclEdges es
   505                     )
   506                     else processTranclEdges es
   507                  end
   508             )
   509           )
   510           else processTranclEdges es;
   511    in processTranclEdges tranclEdges end )
   512 | _ => raise Cannot
   513 
   514 
   515 fun solveTrancl (asms, concl) =
   516  let val (g,_) = mkGraph asms
   517  in
   518   let val (_, subgoal, _) = mkconcl_trancl concl
   519       val (found, path) = findPath (lower subgoal) (upper subgoal) g
   520   in
   521     if found then  [getprf (transPath (tl path, hd path))]
   522     else raise Cannot
   523   end
   524  end;
   525 
   526 fun solveRtrancl (asms, concl) =
   527  let val (g,tranclEdges) = mkGraph asms
   528      val (_, subgoal, _) = mkconcl_rtrancl concl
   529 in
   530   findRtranclProof g tranclEdges subgoal
   531 end;
   532 
   533 
   534 fun trancl_tac ctxt = SUBGOAL (fn (A, n) => fn st =>
   535  let
   536   val thy = ProofContext.theory_of ctxt;
   537   val Hs = Logic.strip_assums_hyp A;
   538   val C = Logic.strip_assums_concl A;
   539   val (rel, subgoals, prf) = mkconcl_trancl C;
   540 
   541   val prems = flat (ListPair.map (mkasm_trancl rel) (Hs, 0 upto (length Hs - 1)));
   542   val prfs = solveTrancl (prems, C);
   543  in
   544   Subgoal.FOCUS (fn {prems, ...} =>
   545     let val thms = map (prove thy rel prems) prfs
   546     in rtac (prove thy rel thms prf) 1 end) ctxt n st
   547  end
   548  handle Cannot => Seq.empty);
   549 
   550 
   551 fun rtrancl_tac ctxt = SUBGOAL (fn (A, n) => fn st =>
   552  let
   553   val thy = ProofContext.theory_of ctxt;
   554   val Hs = Logic.strip_assums_hyp A;
   555   val C = Logic.strip_assums_concl A;
   556   val (rel, subgoals, prf) = mkconcl_rtrancl C;
   557 
   558   val prems = flat (ListPair.map (mkasm_rtrancl rel) (Hs, 0 upto (length Hs - 1)));
   559   val prfs = solveRtrancl (prems, C);
   560  in
   561   Subgoal.FOCUS (fn {prems, ...} =>
   562     let val thms = map (prove thy rel prems) prfs
   563     in rtac (prove thy rel thms prf) 1 end) ctxt n st
   564  end
   565  handle Cannot => Seq.empty | Subscript => Seq.empty);
   566 
   567 end;