src/Provers/trancl.ML
author wenzelm
Mon Jan 08 16:45:30 2018 +0100 (21 months ago)
changeset 67379 c2dfc510a38c
parent 60785 c6f96559e032
permissions -rw-r--r--
prefer qualified names;
     1 (*  Title:      Provers/trancl.ML
     2     Author:     Oliver Kutter, TU Muenchen
     3 
     4 Transitivity reasoner for transitive closures of relations
     5 *)
     6 
     7 (*
     8 
     9 This package provides tactics trancl_tac and rtrancl_tac that prove
    10 goals of the form
    11 
    12    (x,y) : r^+     and     (x,y) : r^* (rtrancl_tac only)
    13 
    14 from premises of the form
    15 
    16    (x,y) : r,     (x,y) : r^+     and     (x,y) : r^* (rtrancl_tac only)
    17 
    18 by reflexivity and transitivity.  The relation r is determined by inspecting
    19 the conclusion.
    20 
    21 The package is implemented as an ML functor and thus not limited to
    22 particular constructs for transitive and reflexive-transitive
    23 closures, neither need relations be represented as sets of pairs.  In
    24 order to instantiate the package for transitive closure only, supply
    25 dummy theorems to the additional rules for reflexive-transitive
    26 closures, and don't use rtrancl_tac!
    27 
    28 *)
    29 
    30 signature TRANCL_ARITH =
    31 sig
    32 
    33   (* theorems for transitive closure *)
    34 
    35   val r_into_trancl : thm
    36       (* (a,b) : r ==> (a,b) : r^+ *)
    37   val trancl_trans : thm
    38       (* [| (a,b) : r^+ ; (b,c) : r^+ |] ==> (a,c) : r^+ *)
    39 
    40   (* additional theorems for reflexive-transitive closure *)
    41 
    42   val rtrancl_refl : thm
    43       (* (a,a): r^* *)
    44   val r_into_rtrancl : thm
    45       (* (a,b) : r ==> (a,b) : r^* *)
    46   val trancl_into_rtrancl : thm
    47       (* (a,b) : r^+ ==> (a,b) : r^* *)
    48   val rtrancl_trancl_trancl : thm
    49       (* [| (a,b) : r^* ; (b,c) : r^+ |] ==> (a,c) : r^+ *)
    50   val trancl_rtrancl_trancl : thm
    51       (* [| (a,b) : r^+ ; (b,c) : r^* |] ==> (a,c) : r^+ *)
    52   val rtrancl_trans : thm
    53       (* [| (a,b) : r^* ; (b,c) : r^* |] ==> (a,c) : r^* *)
    54 
    55   (* decomp: decompose a premise or conclusion
    56 
    57      Returns one of the following:
    58 
    59      NONE if not an instance of a relation,
    60      SOME (x, y, r, s) if instance of a relation, where
    61        x: left hand side argument, y: right hand side argument,
    62        r: the relation,
    63        s: the kind of closure, one of
    64             "r":   the relation itself,
    65             "r^+": transitive closure of the relation,
    66             "r^*": reflexive-transitive closure of the relation
    67   *)
    68 
    69   val decomp: term ->  (term * term * term * string) option
    70 
    71 end;
    72 
    73 signature TRANCL_TAC =
    74 sig
    75   val trancl_tac: Proof.context -> int -> tactic
    76   val rtrancl_tac: Proof.context -> int -> tactic
    77 end;
    78 
    79 functor Trancl_Tac(Cls: TRANCL_ARITH): TRANCL_TAC =
    80 struct
    81 
    82 
    83 datatype proof
    84   = Asm of int
    85   | Thm of proof list * thm;
    86 
    87 exception Cannot; (* internal exception: raised if no proof can be found *)
    88 
    89 fun decomp t = Option.map (fn (x, y, rel, r) =>
    90   (Envir.beta_eta_contract x, Envir.beta_eta_contract y,
    91    Envir.beta_eta_contract rel, r)) (Cls.decomp t);
    92 
    93 fun prove ctxt r asms =
    94   let
    95     fun inst thm =
    96       let val SOME (_, _, Var (r', _), _) = decomp (Thm.concl_of thm)
    97       in infer_instantiate ctxt [(r', Thm.cterm_of ctxt r)] thm end;
    98     fun pr (Asm i) = nth asms i
    99       | pr (Thm (prfs, thm)) = map pr prfs MRS inst thm;
   100   in pr end;
   101 
   102 
   103 (* Internal datatype for inequalities *)
   104 datatype rel
   105    = Trans  of term * term * proof  (* R^+ *)
   106    | RTrans of term * term * proof; (* R^* *)
   107 
   108  (* Misc functions for datatype rel *)
   109 fun lower (Trans (x, _, _)) = x
   110   | lower (RTrans (x,_,_)) = x;
   111 
   112 fun upper (Trans (_, y, _)) = y
   113   | upper (RTrans (_,y,_)) = y;
   114 
   115 fun getprf   (Trans   (_, _, p)) = p
   116 |   getprf   (RTrans (_,_, p)) = p;
   117 
   118 (* ************************************************************************ *)
   119 (*                                                                          *)
   120 (*  mkasm_trancl Rel (t,n): term -> (term , int) -> rel list                *)
   121 (*                                                                          *)
   122 (*  Analyse assumption t with index n with respect to relation Rel:         *)
   123 (*  If t is of the form "(x, y) : Rel" (or Rel^+), translate to             *)
   124 (*  an object (singleton list) of internal datatype rel.                    *)
   125 (*  Otherwise return empty list.                                            *)
   126 (*                                                                          *)
   127 (* ************************************************************************ *)
   128 
   129 fun mkasm_trancl  Rel  (t, n) =
   130   case decomp t of
   131     SOME (x, y, rel,r) => if rel aconv Rel then
   132 
   133     (case r of
   134       "r"   => [Trans (x,y, Thm([Asm n], Cls.r_into_trancl))]
   135     | "r+"  => [Trans (x,y, Asm n)]
   136     | "r*"  => []
   137     | _     => error ("trancl_tac: unknown relation symbol"))
   138     else []
   139   | NONE => [];
   140 
   141 (* ************************************************************************ *)
   142 (*                                                                          *)
   143 (*  mkasm_rtrancl Rel (t,n): term -> (term , int) -> rel list               *)
   144 (*                                                                          *)
   145 (*  Analyse assumption t with index n with respect to relation Rel:         *)
   146 (*  If t is of the form "(x, y) : Rel" (or Rel^+ or Rel^* ), translate to   *)
   147 (*  an object (singleton list) of internal datatype rel.                    *)
   148 (*  Otherwise return empty list.                                            *)
   149 (*                                                                          *)
   150 (* ************************************************************************ *)
   151 
   152 fun mkasm_rtrancl Rel (t, n) =
   153   case decomp t of
   154    SOME (x, y, rel, r) => if rel aconv Rel then
   155     (case r of
   156       "r"   => [ Trans (x,y, Thm([Asm n], Cls.r_into_trancl))]
   157     | "r+"  => [ Trans (x,y, Asm n)]
   158     | "r*"  => [ RTrans(x,y, Asm n)]
   159     | _     => error ("rtrancl_tac: unknown relation symbol" ))
   160    else []
   161   | NONE => [];
   162 
   163 (* ************************************************************************ *)
   164 (*                                                                          *)
   165 (*  mkconcl_trancl t: term -> (term, rel, proof)                            *)
   166 (*  mkconcl_rtrancl t: term -> (term, rel, proof)                           *)
   167 (*                                                                          *)
   168 (*  Analyse conclusion t:                                                   *)
    169 (*    - must be of the form "(x, y) : r^+" (or r^* for rtrancl)             *)
   170 (*    - returns r                                                           *)
   171 (*    - conclusion in internal form                                         *)
   172 (*    - proof object                                                        *)
   173 (*                                                                          *)
   174 (* ************************************************************************ *)
   175 
   176 fun mkconcl_trancl  t =
   177   case decomp t of
   178     SOME (x, y, rel, r) => (case r of
   179       "r+"  => (rel, Trans (x,y, Asm ~1), Asm 0)
   180     | _     => raise Cannot)
   181   | NONE => raise Cannot;
   182 
   183 fun mkconcl_rtrancl  t =
   184   case decomp t of
   185     SOME (x,  y, rel,r ) => (case r of
   186       "r+"  => (rel, Trans (x,y, Asm ~1),  Asm 0)
   187     | "r*"  => (rel, RTrans (x,y, Asm ~1), Asm 0)
   188     | _     => raise Cannot)
   189   | NONE => raise Cannot;
   190 
   191 (* ************************************************************************ *)
   192 (*                                                                          *)
   193 (*  makeStep (r1, r2): rel * rel -> rel                                     *)
   194 (*                                                                          *)
   195 (*  Apply transitivity to r1 and r2, obtaining a new element of r^+ or r^*, *)
    196 (*  according to the following rules:                                       *)
   197 (*                                                                          *)
   198 (* ( (a, b) : r^+ , (b,c) : r^+ ) --> (a,c) : r^+                           *)
   199 (* ( (a, b) : r^* , (b,c) : r^+ ) --> (a,c) : r^+                           *)
   200 (* ( (a, b) : r^+ , (b,c) : r^* ) --> (a,c) : r^+                           *)
   201 (* ( (a, b) : r^* , (b,c) : r^* ) --> (a,c) : r^*                           *)
   202 (*                                                                          *)
   203 (* ************************************************************************ *)
   204 
   205 fun makeStep (Trans (a,_,p), Trans(_,c,q))  = Trans (a,c, Thm ([p,q], Cls.trancl_trans))
   206 (* refl. + trans. cls. rules *)
   207 |   makeStep (RTrans (a,_,p), Trans(_,c,q))  = Trans (a,c, Thm ([p,q], Cls.rtrancl_trancl_trancl))
   208 |   makeStep (Trans (a,_,p), RTrans(_,c,q))  = Trans (a,c, Thm ([p,q], Cls.trancl_rtrancl_trancl))
   209 |   makeStep (RTrans (a,_,p), RTrans(_,c,q))  = RTrans (a,c, Thm ([p,q], Cls.rtrancl_trans));
   210 
   211 (* ******************************************************************* *)
   212 (*                                                                     *)
    213 (* transPath (xs, acc): (rel list * rel) -> rel                       *)
    214 (*                                                                     *)
    215 (* If a path represented by a list of elements of type rel is found,   *)
    216 (* this needs to be contracted to a single element of type rel.        *)
    217 (* The edges are folded onto acc left-to-right by makeStep, which      *)
    218 (* does not check that consecutive endpoints agree.                    *)
   219 (*                                                                     *)
   220 (* ******************************************************************* *)
   221 
   222 fun transPath ([],acc) = acc
   223 |   transPath (x::xs,acc) = transPath (xs, makeStep(acc,x))
   224 
   225 (* ********************************************************************* *)
   226 (* Graph functions                                                       *)
   227 (* ********************************************************************* *)
   228 
   229 (* *********************************************************** *)
   230 (* Functions for constructing graphs                           *)
   231 (* *********************************************************** *)
   232 
   233 fun addEdge (v,d,[]) = [(v,d)]
   234 |   addEdge (v,d,((u,dl)::el)) = if v aconv u then ((v,d@dl)::el)
   235     else (u,dl):: (addEdge(v,d,el));
   236 
   237 (* ********************************************************************** *)
   238 (*                                                                        *)
   239 (* mkGraph constructs from a list of objects of type rel  a graph g       *)
   240 (* and a list of all edges with label r+.                                 *)
   241 (*                                                                        *)
   242 (* ********************************************************************** *)
   243 
   244 fun mkGraph [] = ([],[])
   245 |   mkGraph ys =
   246  let
   247   fun buildGraph ([],g,zs) = (g,zs)
   248   |   buildGraph (x::xs, g, zs) =
   249         case x of (Trans (_,_,_)) =>
   250                buildGraph (xs, addEdge((upper x), [],(addEdge ((lower x),[((upper x),x)],g))), x::zs)
   251         | _ => buildGraph (xs, addEdge((upper x), [],(addEdge ((lower x),[((upper x),x)],g))), zs)
   252 in buildGraph (ys, [], []) end;
   253 
   254 (* *********************************************************************** *)
   255 (*                                                                         *)
   256 (* adjacent g u : (''a * 'b list ) list -> ''a -> 'b list                  *)
   257 (*                                                                         *)
   258 (* List of successors of u in graph g                                      *)
   259 (*                                                                         *)
   260 (* *********************************************************************** *)
   261 
   262 fun adjacent eq_comp ((v,adj)::el) u =
   263     if eq_comp (u, v) then adj else adjacent eq_comp el u
   264 |   adjacent _  []  _ = []
   265 
   266 (* *********************************************************************** *)
   267 (*                                                                         *)
   268 (* dfs eq_comp g u v:                                                      *)
   269 (* ('a * 'a -> bool) -> ('a  *( 'a * rel) list) list ->                    *)
   270 (* 'a -> 'a -> (bool * ('a * rel) list)                                    *)
   271 (*                                                                         *)
   272 (* Depth first search of v from u.                                         *)
   273 (* Returns (true, path(u, v)) if successful, otherwise (false, []).        *)
   274 (*                                                                         *)
   275 (* *********************************************************************** *)
   276 
   277 fun dfs eq_comp g u v =
   278  let
   279     val pred = Unsynchronized.ref [];
   280     val visited = Unsynchronized.ref [];
   281 
   282     fun been_visited v = exists (fn w => eq_comp (w, v)) (!visited)
   283 
   284     fun dfs_visit u' =
   285     let val _ = visited := u' :: (!visited)
   286 
   287     fun update (x,l) = let val _ = pred := (x,l) ::(!pred) in () end;
   288 
   289     in if been_visited v then ()
   290     else (List.app (fn (v',l) => if been_visited v' then () else (
   291        update (v',l);
   292        dfs_visit v'; ()) )) (adjacent eq_comp g u')
   293      end
   294   in
   295     dfs_visit u;
   296     if (been_visited v) then (true, (!pred)) else (false , [])
   297   end;
   298 
   299 (* *********************************************************************** *)
   300 (*                                                                         *)
   301 (* transpose g:                                                            *)
   302 (* (''a * ''a list) list -> (''a * ''a list) list                          *)
   303 (*                                                                         *)
   304 (* Computes transposed graph g' from g                                     *)
   305 (* by reversing all edges u -> v to v -> u                                 *)
   306 (*                                                                         *)
   307 (* *********************************************************************** *)
   308 
   309 fun transpose eq_comp g =
   310   let
   311    (* Compute list of reversed edges for each adjacency list *)
   312    fun flip (u,(v,l)::el) = (v,(u,l)) :: flip (u,el)
   313      | flip (_,[]) = []
   314 
   315    (* Compute adjacency list for node u from the list of edges
   316       and return a likewise reduced list of edges.  The list of edges
   317       is searches for edges starting from u, and these edges are removed. *)
   318    fun gather (u,(v,w)::el) =
   319     let
   320      val (adj,edges) = gather (u,el)
   321     in
   322      if eq_comp (u, v) then (w::adj,edges)
   323      else (adj,(v,w)::edges)
   324     end
   325    | gather (_,[]) = ([],[])
   326 
   327    (* For every node in the input graph, call gather to find all reachable
   328       nodes in the list of edges *)
   329    fun assemble ((u,_)::el) edges =
   330        let val (adj,edges) = gather (u,edges)
   331        in (u,adj) :: assemble el edges
   332        end
   333      | assemble [] _ = []
   334 
   335    (* Compute, for each adjacency list, the list with reversed edges,
   336       and concatenate these lists. *)
   337    val flipped = maps flip g
   338 
   339  in assemble g flipped end
   340 
   341 (* *********************************************************************** *)
   342 (*                                                                         *)
   343 (* dfs_reachable eq_comp g u:                                              *)
   344 (* (int * int list) list -> int -> int list                                *)
   345 (*                                                                         *)
   346 (* Computes list of all nodes reachable from u in g.                       *)
   347 (*                                                                         *)
   348 (* *********************************************************************** *)
   349 
   350 fun dfs_reachable eq_comp g u =
   351  let
   352   (* List of vertices which have been visited. *)
   353   val visited  = Unsynchronized.ref [];
   354 
   355   fun been_visited v = exists (fn w => eq_comp (w, v)) (!visited)
   356 
   357   fun dfs_visit g u  =
   358       let
   359    val _ = visited := u :: !visited
   360    val descendents =
   361        List.foldr (fn ((v,_),ds) => if been_visited v then ds
   362             else v :: dfs_visit g v @ ds)
   363         [] (adjacent eq_comp g u)
   364    in  descendents end
   365 
   366  in u :: dfs_visit g u end;
   367 
   368 (* *********************************************************************** *)
   369 (*                                                                         *)
   370 (* dfs_term_reachable g u:                                                  *)
   371 (* (term * term list) list -> term -> term list                            *)
   372 (*                                                                         *)
   373 (* Computes list of all nodes reachable from u in g.                       *)
   374 (*                                                                         *)
   375 (* *********************************************************************** *)
   376 
   377 fun dfs_term_reachable g u = dfs_reachable (op aconv) g u;
   378 
   379 (* ************************************************************************ *)
   380 (*                                                                          *)
   381 (* findPath x y g: Term.term -> Term.term ->                                *)
   382 (*                  (Term.term * (Term.term * rel list) list) ->            *)
   383 (*                  (bool, rel list)                                        *)
   384 (*                                                                          *)
   385 (*  Searches a path from vertex x to vertex y in Graph g, returns true and  *)
   386 (*  the list of edges if path is found, otherwise false and nil.            *)
   387 (*                                                                          *)
   388 (* ************************************************************************ *)
   389 
   390 fun findPath x y g =
   391   let
   392    val (found, tmp) =  dfs (op aconv) g x y ;
   393    val pred = map snd tmp;
   394 
   395    fun path x y  =
   396     let
   397          (* find predecessor u of node v and the edge u -> v *)
   398 
   399       fun lookup v [] = raise Cannot
   400       |   lookup v (e::es) = if (upper e) aconv v then e else lookup v es;
   401 
   402       (* traverse path backwards and return list of visited edges *)
   403       fun rev_path v =
   404         let val l = lookup v pred
   405             val u = lower l;
   406         in
   407           if u aconv x then [l] else (rev_path u) @ [l]
   408         end
   409 
   410     in rev_path y end;
   411 
   412    in
   413 
   414 
   415       if found then ( (found, (path x y) )) else (found,[])
   416 
   417 
   418 
   419    end;
   420 
   421 (* ************************************************************************ *)
   422 (*                                                                          *)
   423 (* findRtranclProof g tranclEdges subgoal:                                  *)
   424 (* (Term.term * (Term.term * rel list) list) -> rel -> proof list           *)
   425 (*                                                                          *)
   426 (* Searches in graph g a proof for subgoal.                                 *)
   427 (*                                                                          *)
   428 (* ************************************************************************ *)
   429 
   430 fun findRtranclProof g tranclEdges subgoal =
   431    case subgoal of (RTrans (x,y,_)) => if x aconv y then [Thm ([], Cls.rtrancl_refl)] else (
   432      let val (found, path) = findPath (lower subgoal) (upper subgoal) g
   433      in
   434        if found then (
   435           let val path' = (transPath (tl path, hd path))
   436           in
   437 
   438             case path' of (Trans (_,_,p)) => [Thm ([p], Cls.trancl_into_rtrancl )]
   439             | _ => [getprf path']
   440 
   441           end
   442        )
   443        else raise Cannot
   444      end
   445    )
   446 
   447 | (Trans (x,y,_)) => (
   448 
   449   let
   450    val Vx = dfs_term_reachable g x;
   451    val g' = transpose (op aconv) g;
   452    val Vy = dfs_term_reachable g' y;
   453 
   454    fun processTranclEdges [] = raise Cannot
   455    |   processTranclEdges (e::es) =
   456           if member (op =) Vx (upper e) andalso member (op =) Vx (lower e)
   457           andalso member (op =) Vy (upper e) andalso member (op =) Vy (lower e)
   458           then (
   459 
   460 
   461             if (lower e) aconv x then (
   462               if (upper e) aconv y then (
   463                   [(getprf e)]
   464               )
   465               else (
   466                   let
   467                     val (found,path) = findPath (upper e) y g
   468                   in
   469 
   470                    if found then (
   471                        [getprf (transPath (path, e))]
   472                       ) else processTranclEdges es
   473 
   474                   end
   475               )
   476             )
   477             else if (upper e) aconv y then (
   478                let val (xufound,xupath) = findPath x (lower e) g
   479                in
   480 
   481                   if xufound then (
   482 
   483                     let val xuRTranclEdge = transPath (tl xupath, hd xupath)
   484                             val xyTranclEdge = makeStep(xuRTranclEdge,e)
   485 
   486                                 in [getprf xyTranclEdge] end
   487 
   488                  ) else processTranclEdges es
   489 
   490                end
   491             )
   492             else (
   493 
   494                 let val (xufound,xupath) = findPath x (lower e) g
   495                     val (vyfound,vypath) = findPath (upper e) y g
   496                  in
   497                     if xufound then (
   498                          if vyfound then (
   499                             let val xuRTranclEdge = transPath (tl xupath, hd xupath)
   500                                 val vyRTranclEdge = transPath (tl vypath, hd vypath)
   501                                 val xyTranclEdge = makeStep (makeStep(xuRTranclEdge,e),vyRTranclEdge)
   502 
   503                                 in [getprf xyTranclEdge] end
   504 
   505                          ) else processTranclEdges es
   506                     )
   507                     else processTranclEdges es
   508                  end
   509             )
   510           )
   511           else processTranclEdges es;
   512    in processTranclEdges tranclEdges end )
   513 
   514 
   515 fun solveTrancl (asms, concl) =
   516  let val (g,_) = mkGraph asms
   517  in
   518   let val (_, subgoal, _) = mkconcl_trancl concl
   519       val (found, path) = findPath (lower subgoal) (upper subgoal) g
   520   in
   521     if found then  [getprf (transPath (tl path, hd path))]
   522     else raise Cannot
   523   end
   524  end;
   525 
   526 fun solveRtrancl (asms, concl) =
   527  let val (g,tranclEdges) = mkGraph asms
   528      val (_, subgoal, _) = mkconcl_rtrancl concl
   529 in
   530   findRtranclProof g tranclEdges subgoal
   531 end;
   532 
   533 
   534 fun trancl_tac ctxt = SUBGOAL (fn (A, n) => fn st =>
   535  let
   536   val Hs = Logic.strip_assums_hyp A;
   537   val C = Logic.strip_assums_concl A;
   538   val (rel, _, prf) = mkconcl_trancl C;
   539 
   540   val prems = flat (map_index (mkasm_trancl rel o swap) Hs);
   541   val prfs = solveTrancl (prems, C);
   542  in
   543   Subgoal.FOCUS (fn {context = ctxt', prems, concl, ...} =>
   544     let
   545       val SOME (_, _, rel', _) = decomp (Thm.term_of concl);
   546       val thms = map (prove ctxt' rel' prems) prfs
   547     in resolve_tac ctxt' [prove ctxt' rel' thms prf] 1 end) ctxt n st
   548  end
   549  handle Cannot => Seq.empty);
   550 
   551 
   552 fun rtrancl_tac ctxt = SUBGOAL (fn (A, n) => fn st =>
   553  let
   554   val Hs = Logic.strip_assums_hyp A;
   555   val C = Logic.strip_assums_concl A;
   556   val (rel, _, prf) = mkconcl_rtrancl C;
   557 
   558   val prems = flat (map_index (mkasm_rtrancl rel o swap) Hs);
   559   val prfs = solveRtrancl (prems, C);
   560  in
   561   Subgoal.FOCUS (fn {context = ctxt', prems, concl, ...} =>
   562     let
   563       val SOME (_, _, rel', _) = decomp (Thm.term_of concl);
   564       val thms = map (prove ctxt' rel' prems) prfs
   565     in resolve_tac ctxt' [prove ctxt' rel' thms prf] 1 end) ctxt n st
   566  end
   567  handle Cannot => Seq.empty | General.Subscript => Seq.empty);
   568 
   569 end;