src/Pure/General/pretty.scala
changeset 36817 ed97e877ff2d
parent 36763 096ebe74aeaf
child 36818 599466689412
--- a/src/Pure/General/pretty.scala	Tue May 11 23:09:49 2010 +0200
+++ b/src/Pure/General/pretty.scala	Tue May 11 23:36:06 2010 +0200
@@ -45,32 +45,6 @@
 
   /* formatted output */
 
-  private case class Text(tx: List[XML.Tree] = Nil, val pos: Int = 0, val nl: Int = 0)
-  {
-    def newline: Text = copy(tx = FBreak :: tx, pos = 0, nl = nl + 1)
-    def string(s: String): Text = copy(tx = XML.Text(s) :: tx, pos = pos + s.length)
-    def blanks(wd: Int): Text = string(Symbol.spaces(wd))
-    def content: List[XML.Tree] = tx.reverse
-  }
-
-  private def breakdist(trees: List[XML.Tree], after: Int): Int =
-    trees match {
-      case Break(_) :: _ => 0
-      case FBreak :: _ => 0
-      case XML.Elem(_, _, body) :: ts =>
-        (0 /: body)(_ + XML.content_length(_)) + breakdist(ts, after)
-      case XML.Text(s) :: ts => s.length + breakdist(ts, after)
-      case Nil => after
-    }
-
-  private def forcenext(trees: List[XML.Tree]): List[XML.Tree] =
-    trees match {
-      case Nil => Nil
-      case FBreak :: _ => trees
-      case Break(_) :: ts => FBreak :: ts
-      case t :: ts => t :: forcenext(ts)
-    }
-
   private def standard_format(tree: XML.Tree): List[XML.Tree] =
     tree match {
       case XML.Elem(name, atts, body) => List(XML.Elem(name, atts, body.flatMap(standard_format)))
@@ -79,14 +53,47 @@
           Library.chunks(text).toList.map((s: CharSequence) => XML.Text(s.toString)))
     }
 
+  case class Text(tx: List[XML.Tree] = Nil, val pos: Double = 0.0, val nl: Int = 0)
+  {
+    def newline: Text = copy(tx = FBreak :: tx, pos = 0.0, nl = nl + 1)
+    def string(s: String, len: Double): Text = copy(tx = XML.Text(s) :: tx, pos = pos + len)
+    def blanks(wd: Int): Text = string(Symbol.spaces(wd), wd.toDouble)
+    def content: List[XML.Tree] = tx.reverse
+  }
+
   private val margin_default = 76
 
-  def formatted(input: List[XML.Tree], margin: Int = margin_default): List[XML.Tree] =
+  def formatted(input: List[XML.Tree], margin: Int = margin_default,
+    metric: String => Double = (_.length.toDouble)): List[XML.Tree] =
   {
     val breakgain = margin / 20
     val emergencypos = margin / 2
 
-    def format(trees: List[XML.Tree], blockin: Int, after: Int, text: Text): Text =
+    def content_length(tree: XML.Tree): Double =
+      tree match {
+        case XML.Elem(_, _, body) => (0.0 /: body)(_ + content_length(_))
+        case XML.Text(s) => metric(s)
+      }
+
+    def breakdist(trees: List[XML.Tree], after: Double): Double =
+      trees match {
+        case Break(_) :: _ => 0.0
+        case FBreak :: _ => 0.0
+        case XML.Elem(_, _, body) :: ts =>
+          (0.0 /: body)(_ + content_length(_)) + breakdist(ts, after)
+        case XML.Text(s) :: ts => metric(s) + breakdist(ts, after)
+        case Nil => after
+      }
+
+    def forcenext(trees: List[XML.Tree]): List[XML.Tree] =
+      trees match {
+        case Nil => Nil
+        case FBreak :: _ => trees
+        case Break(_) :: ts => FBreak :: ts
+        case t :: ts => t :: forcenext(ts)
+      }
+
+    def format(trees: List[XML.Tree], blockin: Double, after: Double, text: Text): Text =
       trees match {
         case Nil => text
 
@@ -103,17 +110,17 @@
         case Break(wd) :: ts =>
           if (text.pos + wd <= (margin - breakdist(ts, after)).max(blockin + breakgain))
             format(ts, blockin, after, text.blanks(wd))
-          else format(ts, blockin, after, text.newline.blanks(blockin))
-        case FBreak :: ts => format(ts, blockin, after, text.newline.blanks(blockin))
+          else format(ts, blockin, after, text.newline.blanks(blockin.toInt))
+        case FBreak :: ts => format(ts, blockin, after, text.newline.blanks(blockin.toInt))
 
         case XML.Elem(name, atts, body) :: ts =>
           val btext = format(body, blockin, breakdist(ts, after), text.copy(tx = Nil))
           val ts1 = if (text.nl < btext.nl) forcenext(ts) else ts
           val btext1 = btext.copy(tx = XML.Elem(name, atts, btext.content) :: text.tx)
           format(ts1, blockin, after, btext1)
-        case XML.Text(s) :: ts => format(ts, blockin, after, text.string(s))
+        case XML.Text(s) :: ts => format(ts, blockin, after, text.string(s, metric(s)))
       }
-    format(input.flatMap(standard_format), 0, 0, Text()).content
+    format(input.flatMap(standard_format), 0.0, 0.0, Text()).content
   }
 
   def string_of(input: List[XML.Tree], margin: Int = margin_default): String =
@@ -128,7 +135,7 @@
       tree match {
         case Block(_, body) => body.flatMap(fmt)
         case Break(wd) => List(XML.Text(Symbol.spaces(wd)))
-        case FBreak => List(XML.Text(" "))
+        case FBreak => List(XML.Text(Symbol.space))
         case XML.Elem(name, atts, body) => List(XML.Elem(name, atts, body.flatMap(fmt)))
         case XML.Text(_) => List(tree)
       }