more uniform/robust detect_repository/is_repository: actually check hg root;
authorwenzelm
Fri, 24 May 2024 16:15:27 +0200
changeset 80188 3956e8b6a9c9
parent 80187 b8918a5a669e
child 80189 e8d4ac2f21ea
more uniform/robust detect_repository/is_repository: actually check hg root;
src/Pure/General/mercurial.scala
--- a/src/Pure/General/mercurial.scala	Fri May 24 15:55:34 2024 +0200
+++ b/src/Pure/General/mercurial.scala	Fri May 24 16:15:27 2024 +0200
@@ -153,26 +153,31 @@
 
   /* repository access */
 
+  def detect_repository(root: Path, ssh: SSH.System = SSH.Local): Option[Repository] =
+    if (ssh.is_dir(root + Path.explode(".hg"))) {
+      val hg = new Repository(root, ssh)
+      val result = hg.command("root")
+      if (result.ok && ssh.eq_file(root, Path.explode(result.out))) Some(hg) else None
+    }
+    else None
+
   def is_repository(root: Path, ssh: SSH.System = SSH.Local): Boolean =
-    ssh.is_dir(root + Path.explode(".hg")) &&
-    new Repository(root, ssh).command("root").ok
+    detect_repository(root, ssh = ssh).isDefined
 
   def id_repository(root: Path, ssh: SSH.System = SSH.Local, rev: String = "tip"): Option[String] =
-    if (is_repository(root, ssh = ssh)) Some(repository(root, ssh = ssh).id(rev = rev)) else None
+    for (hg <- detect_repository(root, ssh = ssh)) yield hg.id(rev = rev)
 
-  def repository(root: Path, ssh: SSH.System = SSH.Local): Repository = {
-    val hg = new Repository(root, ssh)
-    hg.command("root").check
-    hg
-  }
+  def repository(root: Path, ssh: SSH.System = SSH.Local): Repository =
+    detect_repository(root, ssh = ssh) getOrElse error("Bad hg repository " + root.expand)
 
   def self_repository(): Repository = repository(Path.ISABELLE_HOME)
 
   def find_repository(start: Path, ssh: SSH.System = SSH.Local): Option[Repository] = {
     @tailrec def find(root: Path): Option[Repository] =
-      if (is_repository(root, ssh)) Some(repository(root, ssh = ssh))
-      else if (root.is_root) None
-      else find(root + Path.parent)
+      detect_repository(root, ssh = ssh) match {
+        case None => if (root.is_root) None else find(root + Path.parent)
+        case some => some
+      }
 
     find(ssh.expand_path(start))
   }
@@ -459,9 +464,7 @@
     val repos_name =
       proper_string(remote_name) getOrElse local_path.absolute.base.implode
 
-    val local_hg =
-      if (is_repository(local_path)) repository(local_path)
-      else init_repository(local_path)
+    val local_hg = detect_repository(local_path) getOrElse init_repository(local_path)
 
     progress.echo("Local repository " + local_hg.root.absolute)