(* src/Tools/Metis/src/Term.sml
   author blanchet
   Mon, 13 Sep 2010 21:09:43 +0200
   changeset 39348 6f9c9899f99f
   child 39349 2d0a4361c3ef
   permissions -rw-r--r--
   new version of the Metis files *)

(* ========================================================================= *)
(* FIRST ORDER LOGIC TERMS                                                   *)
(* Copyright (c) 2001 Joe Hurd, distributed under the GNU GPL version 2      *)
(* ========================================================================= *)

structure Term :> Term =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of first order logic terms.                                        *)
(* ------------------------------------------------------------------------- *)

(* A variable is identified by its name. *)
type var = Name.name;

(* A function symbol is identified by its name. *)
type functionName = Name.name;

(* A function name paired with an int; fnFunction below fills the int with *)
(* the number of arguments, i.e. the arity.                                *)
type function = functionName * int;

(* A constant is named like a function; mkConst/destConst below treat a *)
(* constant as a function applied to the empty argument list.           *)
type const = functionName;

(* A term is either a variable, or a function name applied to a (possibly *)
(* empty) list of argument terms.                                         *)
datatype term =
    Var of Name.name
  | Fn of Name.name * term list;

(* ------------------------------------------------------------------------- *)
(* Constructors and destructors.                                             *)
(* ------------------------------------------------------------------------- *)

(* Variables *)

(* destVar tm returns the name carried by a variable term and raises *)
(* Error "destVar" when tm is a function application.                *)
fun destVar tm =
    case tm of
      Var name => name
    | Fn _ => raise Error "destVar";

(* Predicate obtained by wrapping destVar with `can`. *)
val isVar = can destVar;

(* equalVar v tm holds exactly when tm is a variable whose name is *)
(* Name.equal to v; every function application yields false.       *)
fun equalVar v tm =
    case tm of
      Var v' => Name.equal v v'
    | Fn _ => false;

(* Functions *)

(* destFn tm returns the (name, arguments) pair of a function application *)
(* and raises Error "destFn" when tm is a variable.                       *)
fun destFn tm =
    case tm of
      Fn nameArgs => nameArgs
    | Var _ => raise Error "destFn";

(* Predicate obtained by wrapping destFn with `can`. *)
val isFn = can destFn;

(* The accessors below all go through destFn, so each of them raises *)
(* Error "destFn" when applied to a variable.                        *)

(* Name of the head function symbol. *)
fun fnName tm =
    let
      val (name,_) = destFn tm
    in
      name
    end;

(* Argument list of the application. *)
fun fnArguments tm =
    let
      val (_,args) = destFn tm
    in
      args
    end;

(* Number of arguments of the application. *)
fun fnArity tm = length (fnArguments tm);

(* Head function symbol together with its arity. *)
fun fnFunction tm =
    let
      val (name,args) = destFn tm
    in
      (name, length args)
    end;

local
  (* Worklist traversal: `pending` holds the terms still to be visited and *)
  (* `acc` the (name, arity) pairs collected so far.                       *)
  fun collect acc pending =
      case pending of
        [] => acc
      | Var _ :: rest => collect acc rest
      | Fn (name,args) :: rest =>
        let
          val acc' = NameAritySet.add acc (name, length args)
        in
          collect acc' (args @ rest)
        end;
in
  (* The set of (function name, arity) pairs occurring anywhere in tm. *)
  fun functions tm = collect NameAritySet.empty [tm];
end;

local
  (* Worklist traversal collecting the name of every Fn node visited. *)
  fun collect acc pending =
      case pending of
        [] => acc
      | Var _ :: rest => collect acc rest
      | Fn (name,args) :: rest =>
        let
          val acc' = NameSet.add acc name
        in
          collect acc' (args @ rest)
        end;
in
  (* The set of function names occurring anywhere in tm (arity ignored). *)
  fun functionNames tm = collect NameSet.empty [tm];
end;

(* Constants *)

(* A constant is a function symbol applied to no arguments. *)
fun mkConst c = Fn (c, []);

(* destConst tm returns the name of a nullary function application and *)
(* raises Error "destConst" on every other term.                       *)
fun destConst tm =
    case tm of
      Fn (c, []) => c
    | _ => raise Error "destConst";

(* Predicate obtained by wrapping destConst with `can`. *)
val isConst = can destConst;

(* Binary functions *)

(* mkBinop f (a,b) builds the two-argument application f(a,b). *)
fun mkBinop f = fn (a,b) => Fn (f, [a,b]);

(* destBinop f tm returns the two arguments of tm when tm is a two-argument *)
(* application of f.  It raises Error with a message distinguishing "some   *)
(* other binary symbol" from "not a binary application at all".             *)
fun destBinop f tm =
    case tm of
      Fn (x,[a,b]) =>
      if not (Name.equal x f) then
        raise Error "Term.destBinop: wrong binop"
      else
        (a,b)
    | _ => raise Error "Term.destBinop: not a binop";

(* Predicate obtained by wrapping (destBinop f) with `can`. *)
fun isBinop f = can (destBinop f);

(* ------------------------------------------------------------------------- *)
(* The size of a term in symbols.                                            *)
(* ------------------------------------------------------------------------- *)

(* Weight contributed by each variable occurrence. *)
val VAR_SYMBOLS = 1;

(* Weight contributed by each function symbol occurrence. *)
val FN_SYMBOLS = 1;

local
  (* Worklist traversal: `n` is the running total, `pending` the terms *)
  (* that still have to be counted.                                    *)
  fun count n pending =
      case pending of
        [] => n
      | Var _ :: rest => count (n + VAR_SYMBOLS) rest
      | Fn (_,args) :: rest => count (n + FN_SYMBOLS) (args @ rest);
in
  (* Total weight of all variable and function symbol occurrences in tm. *)
  fun symbols tm = count 0 [tm];
end;

(* ------------------------------------------------------------------------- *)
(* A total comparison function for terms.                                    *)
(* ------------------------------------------------------------------------- *)

local
  (* cmp walks two worklists of terms in lock step, comparing them pairwise. *)
  (* The two lists always have the same length: arguments are only pushed    *)
  (* after the arities have been found equal, so a length mismatch is        *)
  (* reported as a Bug rather than an ordinary Error.                        *)
  fun cmp [] [] = EQUAL
    | cmp (tm1 :: tms1) (tm2 :: tms2) =
      let
        val tm1_tm2 = (tm1,tm2)
      in
        (* Shortcut: skip the structural comparison when pointerEqual says *)
        (* the two terms are the same (presumably physical identity).      *)
        if Portable.pointerEqual tm1_tm2 then cmp tms1 tms2
        else
          case tm1_tm2 of
            (Var v1, Var v2) =>
            (case Name.compare (v1,v2) of
               LESS => LESS
             | EQUAL => cmp tms1 tms2
             | GREATER => GREATER)
            (* Every variable is ordered before every function application. *)
          | (Var _, Fn _) => LESS
          | (Fn _, Var _) => GREATER
            (* Applications: order by name, then arity, then the arguments *)
            (* from left to right.                                         *)
          | (Fn (f1,a1), Fn (f2,a2)) =>
            (case Name.compare (f1,f2) of
               LESS => LESS
             | EQUAL =>
               (case Int.compare (length a1, length a2) of
                  LESS => LESS
                | EQUAL => cmp (a1 @ tms1) (a2 @ tms2)
                | GREATER => GREATER)
             | GREATER => GREATER)
      end
    | cmp _ _ = raise Bug "Term.compare";
in
  (* Total order on terms, returned as LESS / EQUAL / GREATER. *)
  fun compare (tm1,tm2) = cmp [tm1] [tm2];
end;

(* Structural equality of terms, defined as compare returning EQUAL. *)
fun equal tm1 tm2 =
    case compare (tm1,tm2) of
      EQUAL => true
    | _ => false;

(* ------------------------------------------------------------------------- *)
(* Subterms.                                                                 *)
(* ------------------------------------------------------------------------- *)

(* A path addresses a subterm: each int selects an argument position     *)
(* (counted from 0, as used with List.nth below), outermost first.  The  *)
(* empty path addresses the term itself.                                 *)
type path = int list;

(* subterm tm path returns the subterm of tm addressed by path.           *)
(*                                                                        *)
(* Raises Error "Term.subterm: Var" when the path descends into a         *)
(* variable, and Error "Term.subterm: Fn" when an index does not select   *)
(* an existing argument.  Negative indices are rejected explicitly so     *)
(* that they surface as Error rather than as Subscript from List.nth.     *)
fun subterm tm [] = tm
  | subterm (Var _) (_ :: _) = raise Error "Term.subterm: Var"
  | subterm (Fn (_,tms)) (h :: t) =
    if h < 0 orelse h >= length tms then raise Error "Term.subterm: Fn"
    else subterm (List.nth (tms,h)) t;

local
  (* Worklist traversal.  Paths on the worklist are kept reversed          *)
  (* (innermost index first) so extending one is a cons; they are put in   *)
  (* outermost-first order with `rev` when recorded in the accumulator.    *)
  fun walk pending acc =
      case pending of
        [] => acc
      | (path,tm) :: rest =>
        let
          val acc = (rev path, tm) :: acc
        in
          case tm of
            Var _ => walk rest acc
          | Fn (_,args) =>
            let
              val children =
                  map (fn (i,arg) => (i :: path, arg)) (enumerate args)
            in
              walk (children @ rest) acc
            end
        end;
in
  (* Every subterm of tm (including tm itself) paired with its path. *)
  fun subterms tm = walk [([],tm)] [];
end;

(* replace tm (path,res) returns tm with the subterm at path replaced by  *)
(* res.                                                                   *)
(*                                                                        *)
(* When the replacement is structurally equal to what is already there,   *)
(* the original term is returned unchanged, so no new nodes are built     *)
(* along the path (checked with pointerEqual on the way back up).         *)
(*                                                                        *)
(* Raises Error "Term.replace: Var" when the path descends into a         *)
(* variable, and Error "Term.replace: Fn" when an index does not select   *)
(* an existing argument.  Negative indices are rejected explicitly so     *)
(* that they surface as Error rather than as Subscript from List.nth.     *)
fun replace tm ([],res) = if equal res tm then tm else res
  | replace tm (h :: t, res) =
    case tm of
      Var _ => raise Error "Term.replace: Var"
    | Fn (func,tms) =>
      if h < 0 orelse h >= length tms then raise Error "Term.replace: Fn"
      else
        let
          val arg = List.nth (tms,h)
          val arg' = replace arg (t,res)
        in
          if Portable.pointerEqual (arg',arg) then tm
          else Fn (func, updateNth (h,arg') tms)
        end;

(* find pred tm searches tm for a subterm satisfying pred and returns     *)
(* SOME of its path, or NONE when no subterm qualifies.  A term is tested *)
(* before its arguments, and arguments are pushed in front of the         *)
(* remaining worklist, so the first hit in that visiting order wins.      *)
fun find pred =
    let
      (* Paths on the worklist are stored reversed and flipped on return. *)
      fun search pending =
          case pending of
            [] => NONE
          | (path,tm) :: rest =>
            if pred tm then SOME (rev path)
            else
              case tm of
                Var _ => search rest
              | Fn (_,args) =>
                let
                  fun extend (i,arg) = (i :: path, arg)
                in
                  search (map extend (enumerate args) @ rest)
                end
    in
      fn tm => search [([],tm)]
    end;

(* Pretty printer for paths: the list printer applied to the int printer. *)
val ppPath = Print.ppList Print.ppInt;

(* Render a path as a string using ppPath. *)
val pathToString = Print.toString ppPath;

(* ------------------------------------------------------------------------- *)
(* Free variables.                                                           *)
(* ------------------------------------------------------------------------- *)

local
  (* Worklist search that stops at the first variable Name.equal to v. *)
  fun occurs v pending =
      case pending of
        [] => false
      | Var w :: rest => if Name.equal v w then true else occurs v rest
      | Fn (_,args) :: rest => occurs v (args @ rest);
in
  (* freeIn v tm holds when the variable v occurs somewhere in tm. *)
  fun freeIn v tm = occurs v [tm];
end;

local
  (* Worklist traversal adding every variable name encountered to vs. *)
  fun gather vs pending =
      case pending of
        [] => vs
      | Var v :: rest => gather (NameSet.add vs v) rest
      | Fn (_,args) :: rest => gather vs (args @ rest);
in
  (* The set of variable names occurring in any term of the list. *)
  val freeVarsList = gather NameSet.empty;

  (* The set of variable names occurring in a single term. *)
  fun freeVars tm = freeVarsList [tm];
end;

(* ------------------------------------------------------------------------- *)
(* Fresh variables.                                                          *)
(* ------------------------------------------------------------------------- *)

(* A variable term wrapping a name obtained from Name.newName. *)
fun newVar () =
    let
      val fresh = Name.newName ()
    in
      Var fresh
    end;

(* Variable terms wrapping the names obtained from Name.newNames n. *)
fun newVars n =
    let
      val fresh = Name.newNames n
    in
      map Var fresh
    end;

(* Both functions delegate to the Name variant generators, passing them a *)
(* predicate that accepts exactly the names not contained in `avoid`.     *)

fun variantPrime avoid =
    Name.variantPrime (fn n => not (NameSet.member n avoid));

fun variantNum avoid =
    Name.variantNum (fn n => not (NameSet.member n avoid));

(* ------------------------------------------------------------------------- *)
(* Special support for terms with type annotations.                          *)
(* ------------------------------------------------------------------------- *)

(* The function name used for type annotations: ":". *)
val hasTypeFunctionName = Name.fromString ":";

(* The type annotation symbol together with its arity of 2. *)
val hasTypeFunction = (hasTypeFunctionName,2);

(* destFnHasType (f,a) returns (tm,ty) when f is the type annotation name *)
(* and a is exactly the two-element list [tm,ty]; in every other case it  *)
(* raises Error "Term.destFnHasType".                                     *)
fun destFnHasType ((f,a) : functionName * term list) =
    case (Name.equal f hasTypeFunctionName, a) of
      (true, [tm,ty]) => (tm,ty)
    | _ => raise Error "Term.destFnHasType";

(* Predicate obtained by wrapping destFnHasType with `can`. *)
val isFnHasType = can destFnHasType;

(* isTypedVar tm holds when tm is a bare variable, or a type annotation *)
(* whose annotated subterm is a variable.                               *)
fun isTypedVar (Var _) = true
  | isTypedVar (Fn func) =
    (case total destFnHasType func of
       SOME (Var _, _) => true
     | SOME (Fn _, _) => false
     | NONE => false);

local
  (* Worklist traversal.  A type annotation contributes nothing itself:  *)
  (* only its annotated subterm is counted and the type is skipped.      *)
  (* Every variable and every other function symbol counts as 1.         *)
  fun count n [] = n
    | count n (Var _ :: rest) = count (n + 1) rest
    | count n (Fn (func as (_,args)) :: rest) =
      (case total destFnHasType func of
         SOME (inner,_) => count n (inner :: rest)
       | NONE => count (n + 1) (args @ rest));
in
  (* Symbol count of tm that ignores type annotations and their types. *)
  fun typedSymbols tm = count 0 [tm];
end;

local
  (* Worklist traversal over (reversed path, term) pairs; results are     *)
  (* accumulated as (path, term) with the path in outermost-first order.  *)
  fun subtms [] acc = acc
    | subtms ((path,tm) :: rest) acc =
      case tm of
        (* Bare variables are never reported. *)
        Var _ => subtms rest acc
      | Fn func =>
        case total destFnHasType func of
          SOME (t,_) =>
          (case t of
             (* A type-annotated variable is skipped as well. *)
             Var _ => subtms rest acc
           | Fn _ =>
             (* A type-annotated application: record the annotated term  *)
             (* and continue into the term part only (argument 0); the   *)
             (* type part is not visited.                                *)
             let
               val acc = (rev path, tm) :: acc
               val rest = (0 :: path, t) :: rest
             in
               subtms rest acc
             end)
        | NONE =>
          (* An ordinary application: record it and visit every argument. *)
          let
            fun f (n,arg) = (n :: path, arg)

            val (_,args) = func

            val acc = (rev path, tm) :: acc

            val rest = map f (enumerate args) @ rest
          in
            subtms rest acc
          end;
in
  (* The non-variable subterms of tm with their paths, where type         *)
  (* annotations are looked through and the type parts are not entered.   *)
  fun nonVarTypedSubterms tm = subtms [([],tm)] [];
end;

(* ------------------------------------------------------------------------- *)
(* Special support for terms with an explicit function application operator. *)
(* ------------------------------------------------------------------------- *)

(* The function name used for explicit application: ".". *)
val appName = Name.fromString ".";

(* The (name, arguments) pair representing "fTm applied to aTm". *)
fun mkFnApp (fTm,aTm) = (appName, [fTm,aTm]);

(* The term representing "fTm applied to aTm". *)
fun mkApp (fTm,aTm) = Fn (appName, [fTm,aTm]);

(* destFnApp (f,a) returns (fTm,aTm) when f is the application name and *)
(* a is exactly the two-element list [fTm,aTm]; in every other case it  *)
(* raises Error "Term.destFnApp".                                       *)
fun destFnApp ((f,a) : Name.name * term list) =
    case (Name.equal f appName, a) of
      (true, [fTm,aTm]) => (fTm,aTm)
    | _ => raise Error "Term.destFnApp";

(* Predicate obtained by wrapping destFnApp with `can`. *)
val isFnApp = can destFnApp;

(* destApp tm splits an explicit application term into (function,       *)
(* argument).  Raises Error "Term.destApp" on a variable and whatever   *)
(* destFnApp raises on an application of any other shape.               *)
fun destApp (Var _) = raise Error "Term.destApp"
  | destApp (Fn func) = destFnApp func;

(* Predicate obtained by wrapping destApp with `can`. *)
val isApp = can destApp;

(* listMkApp (f,[a1,...,an]) builds the left-nested application          *)
(* (...((f . a1) . a2)...) . an, the shape that stripApp takes apart.    *)
(*                                                                       *)
(* foldl hands its function the pair (element, accumulator), whereas     *)
(* mkApp expects (function, argument).  The pair is therefore swapped    *)
(* so that the accumulated term stays in the function position; passing  *)
(* mkApp directly would put each argument there instead.                 *)
fun listMkApp (f,l) = foldl (fn (a,acc) => mkApp (acc,a)) f l;

local
  (* Peel applications off the head, pushing each argument onto the     *)
  (* front of `args` so the final list is in left-to-right order.       *)
  fun unwind (head,args) =
      case total destApp head of
        NONE => (head,args)
      | SOME (f,a) => unwind (f, a :: args);
in
  (* stripApp tm splits a left-nested application into its head term and *)
  (* the list of arguments it is applied to.                             *)
  fun stripApp tm = unwind (tm,[]);
end;

(* ------------------------------------------------------------------------- *)
(* Parsing and pretty printing.                                              *)
(* ------------------------------------------------------------------------- *)

(* Operators parsed and printed infix *)

(* The table of infix operators, held in a ref so that it can be changed  *)
(* at run time; pp and the term parser read it afresh on every call.      *)
(* Each entry gives the operator token, its precedence and whether it is  *)
(* left associative.  Most tokens include the surrounding spaces.         *)
val infixes =
    (ref o Print.Infixes)
      [(* ML symbols *)
       {token = " / ", precedence = 7, leftAssoc = true},
       {token = " div ", precedence = 7, leftAssoc = true},
       {token = " mod ", precedence = 7, leftAssoc = true},
       {token = " * ", precedence = 7, leftAssoc = true},
       {token = " + ", precedence = 6, leftAssoc = true},
       {token = " - ", precedence = 6, leftAssoc = true},
       {token = " ^ ", precedence = 6, leftAssoc = true},
       {token = " @ ", precedence = 5, leftAssoc = false},
       {token = " :: ", precedence = 5, leftAssoc = false},
       {token = " = ", precedence = 4, leftAssoc = true},
       {token = " <> ", precedence = 4, leftAssoc = true},
       {token = " <= ", precedence = 4, leftAssoc = true},
       {token = " < ", precedence = 4, leftAssoc = true},
       {token = " >= ", precedence = 4, leftAssoc = true},
       {token = " > ", precedence = 4, leftAssoc = true},
       {token = " o ", precedence = 3, leftAssoc = true},
       {token = " -> ", precedence = 2, leftAssoc = false},  (* inferred prec *)
       {token = " : ", precedence = 1, leftAssoc = false},  (* inferred prec *)
       {token = ", ", precedence = 0, leftAssoc = false},  (* inferred prec *)

       (* Logical connectives *)
       {token = " /\\ ", precedence = ~1, leftAssoc = false},
       {token = " \\/ ", precedence = ~2, leftAssoc = false},
       {token = " ==> ", precedence = ~3, leftAssoc = false},
       {token = " <=> ", precedence = ~4, leftAssoc = false},

       (* Other symbols *)
       {token = " . ", precedence = 9, leftAssoc = true},  (* function app *)
       {token = " ** ", precedence = 8, leftAssoc = true},
       {token = " ++ ", precedence = 6, leftAssoc = true},
       {token = " -- ", precedence = 6, leftAssoc = true},
       {token = " == ", precedence = 4, leftAssoc = true}];

(* The negation symbol *)

(* Mutable; a unary function with this name is printed and parsed as a *)
(* prefix negation.                                                    *)
val negation : string ref = ref "~";

(* Binder symbols *)

(* Mutable; a function with one of these names applied to a variable and *)
(* a body is printed and parsed as a quantifier-style binder.            *)
val binders : string list ref = ref ["\\","!","?","?!"];

(* Bracket symbols *)

(* Mutable list of (open, close) pairs; a unary function whose name is    *)
(* the concatenation open ^ close is printed and parsed as its argument   *)
(* wrapped in that bracket pair.                                          *)
val brackets : (string * string) list ref = ref [("[","]"),("{","}")];

(* Pretty printing *)

(* pp is the pretty printer for terms.  It reads the current contents of   *)
(* the binders, infixes, negation and brackets refs on every call and      *)
(* prints infix operators, negations, binders and custom brackets in their *)
(* special syntax; everything else is printed as a function name followed  *)
(* by its arguments.  The set `bv` threaded through the local functions    *)
(* holds the names (as strings) of the variables bound by enclosing        *)
(* binders.                                                                *)
fun pp inputTerm =
    let
      val quants = !binders
      and iOps = !infixes
      and neg = !negation
      and bracks = !brackets

      (* Each bracket pair becomes (function name, open, close), where the *)
      (* function name is the two bracket strings concatenated.            *)
      val bracks = map (fn (b1,b2) => (b1 ^ b2, b1, b2)) bracks

      val bTokens = map #2 bracks @ map #3 bracks

      val iTokens = Print.tokensInfixes iOps

      (* Recognize a binary application whose function name is one of the *)
      (* infix tokens.                                                    *)
      fun destI tm =
          case tm of
            Fn (f,[a,b]) =>
            let
              val f = Name.toString f
            in
              if StringSet.member f iTokens then SOME (f,a,b) else NONE
            end
          | _ => NONE

      val iPrinter = Print.ppInfixes iOps destI

      (* Names with special meaning in the concrete syntax; a function *)
      (* with such a name has to be escaped when printed.              *)
      val specialTokens =
          StringSet.addList iTokens (neg :: quants @ ["$","(",")"] @ bTokens)

      fun vName bv s = StringSet.member s bv

      (* A variable that is not bound by an enclosing binder is printed *)
      (* with a "$" prefix.                                             *)
      fun checkVarName bv n =
          let
            val s = Name.toString n
          in
            if vName bv s then s else "$" ^ s
          end

      fun varName bv = Print.ppMap (checkVarName bv) Print.ppString

      (* A function name that is a special token, or that coincides with *)
      (* a bound variable name, is printed inside parentheses.           *)
      fun checkFunctionName bv n =
          let
            val s = Name.toString n
          in
            if StringSet.member s specialTokens orelse vName bv s then
              "(" ^ s ^ ")"
            else
              s
          end

      fun functionName bv = Print.ppMap (checkFunctionName bv) Print.ppString

      fun isI tm = Option.isSome (destI tm)

      (* Count and remove the leading negations of a term. *)
      fun stripNeg tm =
          case tm of
            Fn (f,[a]) =>
            if Name.toString f <> neg then (0,tm)
            else let val (n,tm) = stripNeg a in (n + 1, tm) end
          | _ => (0,tm)

      (* Recognize a binder applied to a variable and a body.  Directly   *)
      (* nested uses of the same binder are merged, giving the binder,    *)
      (* the first variable, the further variables and the innermost body.*)
      val destQuant =
          let
            fun dest q (Fn (q', [Var v, body])) =
                if Name.toString q' <> q then NONE
                else
                  (case dest q body of
                     NONE => SOME (q,v,[],body)
                   | SOME (_,v',vs,body) => SOME (q, v, v' :: vs, body))
              | dest _ _ = NONE
          in
            fn tm => Useful.first (fn q => dest q tm) quants
          end

      fun isQuant tm = Option.isSome (destQuant tm)

      (* Recognize a unary application whose function name is the *)
      (* concatenation of a registered bracket pair.              *)
      fun destBrack (Fn (b,[tm])) =
          let
            val s = Name.toString b
          in
            case List.find (fn (n,_,_) => n = s) bracks of
              NONE => NONE
            | SOME (_,b1,b2) => SOME (b1,tm,b2)
          end
        | destBrack _ = NONE

      fun isBrack tm = Option.isSome (destBrack tm)

      (* An argument is preceded by a break.  Custom-bracketed terms,    *)
      (* variables and constants are printed as they are; any other      *)
      (* argument is wrapped in parentheses.                             *)
      fun functionArgument bv tm =
          Print.sequence
            (Print.addBreak 1)
            (if isBrack tm then customBracket bv tm
             else if isVar tm orelse isConst tm then basic bv tm
             else bracket bv tm)

      (* A variable, or a function name followed by its arguments. *)
      and basic bv (Var v) = varName bv v
        | basic bv (Fn (f,args)) =
          Print.blockProgram Print.Inconsistent 2
            (functionName bv f :: map (functionArgument bv) args)

      and customBracket bv tm =
          case destBrack tm of
            SOME (b1,tm,b2) => Print.ppBracket b1 b2 (term bv) tm
          | NONE => basic bv tm

      (* Print "binder v1 v2 ... vn. body"; the bound variables are added *)
      (* to bv before they and the body are printed.                      *)
      and innerQuant bv tm =
          case destQuant tm of
            NONE => term bv tm
          | SOME (q,v,vs,tm) =>
            let
              val bv = StringSet.addList bv (map Name.toString (v :: vs))
            in
              Print.program
                [Print.addString q,
                 varName bv v,
                 Print.program
                   (map (Print.sequence (Print.addBreak 1) o varName bv) vs),
                 Print.addString ".",
                 Print.addBreak 1,
                 innerQuant bv tm]
            end

      and quantifier bv tm =
          if not (isQuant tm) then customBracket bv tm
          else Print.block Print.Inconsistent 2 (innerQuant bv tm)

      (* Print the operand of an infix expression: its leading negations  *)
      (* followed by the remaining term.  The flag r is supplied by the   *)
      (* infix printer (presumably marking an operand with more text to   *)
      (* its right - TODO confirm against Print.ppInfixes); when it is    *)
      (* set, a quantified term is parenthesized.  An infix term under a  *)
      (* negation is always parenthesized.                                *)
      and molecule bv (tm,r) =
          let
            val (n,tm) = stripNeg tm
          in
            Print.blockProgram Print.Inconsistent n
              [Print.duplicate n (Print.addString neg),
               if isI tm orelse (r andalso isQuant tm) then bracket bv tm
               else quantifier bv tm]
          end

      and term bv tm = iPrinter (molecule bv) (tm,false)

      and bracket bv tm = Print.ppBracket "(" ")" (term bv) tm
    in
      (* Printing starts with no bound variables. *)
      term StringSet.empty
    end inputTerm;

(* Render a term as a string using pp. *)
val toString = Print.toString pp;

(* Parsing *)

local
  open Parse;

  infixr 9 >>++
  infixr 8 ++
  infixr 7 >>
  infixr 6 ||

  (* Characters allowed in alphanumeric tokens: letters, digits, *)
  (* underscore and prime.                                       *)
  val isAlphaNum =
      let
        val alphaNumChars = explode "_'"
      in
        fn c => mem c alphaNumChars orelse Char.isAlphaNum c
      end;

  local
    val alphaNumToken = atLeastOne (some isAlphaNum) >> implode;

    (* A symbol token is either a single character equal to the current  *)
    (* negation string, or a run of symbol characters that does not      *)
    (* start with such a character.                                      *)
    val symbolToken =
        let
          fun isNeg c = str c = !negation

          val symbolChars = explode "<>=-*+/\\?@|!$%&#^:;~"

          fun isSymbol c = mem c symbolChars

          fun isNonNegSymbol c = not (isNeg c) andalso isSymbol c
        in
          some isNeg >> str ||
          (some isNonNegSymbol ++ many (some isSymbol)) >> (implode o op::)
        end;

    (* Each punctuation character is a token on its own. *)
    val punctToken =
        let
          val punctChars = explode "()[]{}.,"

          fun isPunct c = mem c punctChars
        in
          some isPunct >> str
        end;

    val lexToken = alphaNumToken || symbolToken || punctToken;

    val space = many (some Char.isSpace);
  in
    (* One token, with any whitespace around it discarded. *)
    val lexer = (space ++ lexToken ++ space) >> (fn (_,(tok,_)) => tok);
  end;

  (* Parser from a stream of string tokens to a term.  It reads the       *)
  (* current contents of the binders, infixes, negation and brackets refs *)
  (* on every call.  The set `bv` threaded through the local functions    *)
  (* holds the variable names bound by enclosing binders.                 *)
  fun termParser inputStream =
      let
        val quants = !binders
        and iOps = !infixes
        and neg = !negation
        and bracks = ("(",")") :: !brackets

        (* Each bracket pair becomes (function name, open, close); plain *)
        (* parentheses are included here with the name "()".             *)
        val bracks = map (fn (b1,b2) => (b1 ^ b2, b1, b2)) bracks

        val bTokens = map #2 bracks @ map #3 bracks

        fun possibleVarName "" = false
          | possibleVarName s = isAlphaNum (String.sub (s,0))

        fun vName bv s = StringSet.member s bv

        val iTokens = Print.tokensInfixes iOps

        val iParser =
            parseInfixes iOps (fn (f,a,b) => Fn (Name.fromString f, [a,b]))

        val specialTokens =
            StringSet.addList iTokens (neg :: quants @ ["$"] @ bTokens)

        (* A variable is either a token naming a bound variable, or a "$" *)
        (* token followed by a token starting with an alphanumeric        *)
        (* character.                                                     *)
        fun varName bv =
            some (vName bv) ||
            (some (Useful.equal "$") ++ some possibleVarName) >> snd

        fun fName bv s =
            not (StringSet.member s specialTokens) andalso not (vName bv s)

        (* A function name is either a token that is neither special nor *)
        (* a bound variable, or any single token enclosed in parentheses. *)
        fun functionName bv =
            some (fName bv) ||
            (some (Useful.equal "(") ++ any ++ some (Useful.equal ")")) >>
            (fn (_,(s,_)) => s)

        (* A basic term: a variable, a constant, a bracketed term or a *)
        (* binder expression.                                          *)
        fun basic bv tokens =
            let
              val var = varName bv >> (Var o Name.fromString)

              val const =
                  functionName bv >> (fn f => Fn (Name.fromString f, []))

              (* Plain parentheses only group; any other bracket pair *)
              (* wraps the term in a unary function named after it.   *)
              fun bracket (ab,a,b) =
                  (some (Useful.equal a) ++ term bv ++ some (Useful.equal b)) >>
                  (fn (_,(tm,_)) =>
                      if ab = "()" then tm else Fn (Name.fromString ab, [tm]))

              (* "q v1 ... vn . body": the body is parsed with the listed  *)
              (* variables added to bv, and the binder is then applied to  *)
              (* each variable in turn, v1 outermost.                      *)
              fun quantifier q =
                  let
                    fun bind (v,t) =
                        Fn (Name.fromString q, [Var (Name.fromString v), t])
                  in
                    (some (Useful.equal q) ++
                     atLeastOne (some possibleVarName) ++
                     some (Useful.equal ".")) >>++
                    (fn (_,(vs,_)) =>
                        term (StringSet.addList bv vs) >>
                        (fn body => foldr bind body vs))
                  end
            in
              var ||
              const ||
              first (map bracket bracks) ||
              first (map quantifier quants)
            end tokens

        (* A molecule: any number of negation tokens followed by either a *)
        (* function name with its basic arguments or a single basic term. *)
        and molecule bv tokens =
            let
              val negations = many (some (Useful.equal neg)) >> length

              val function =
                  (functionName bv ++ many (basic bv)) >>
                  (fn (f,args) => Fn (Name.fromString f, args)) ||
                  basic bv
            in
              (negations ++ function) >>
              (fn (n,tm) => funpow n (fn t => Fn (Name.fromString neg, [t])) tm)
            end tokens

        (* A term: molecules combined by the infix operators. *)
        and term bv tokens = iParser (molecule bv) tokens
      in
        (* Parsing starts with no bound variables. *)
        term StringSet.empty
      end inputStream;
in
  (* fromString input lexes the whole string into tokens and parses the  *)
  (* whole token stream.  Raises Error "Term.fromString" unless the      *)
  (* result is exactly one term.                                         *)
  fun fromString input =
      let
        val chars = Stream.fromList (explode input)

        val tokens = everything (lexer >> singleton) chars

        val terms = everything (termParser >> singleton) tokens
      in
        case Stream.toList terms of
          [tm] => tm
        | _ => raise Error "Term.fromString"
      end;
end;

local
  (* Antiquoted terms are rendered with pp and wrapped in parentheses. *)
  val antiquotedTermToString = Print.toString (Print.ppBracket "(" ")" pp);
in
  (* Quotation parser for terms, built by Parse.parseQuotation from the *)
  (* antiquotation renderer above and fromString.                       *)
  val parse = Parse.parseQuotation antiquotedTermToString fromString;
end;

end

(* Terms packaged with their total order, in the shape passed to the *)
(* KeyMap functor below.                                             *)
structure TermOrdered =
struct type t = Term.term val compare = Term.compare end

(* Finite maps keyed by terms. *)
structure TermMap = KeyMap (TermOrdered);

(* Finite sets of terms, built on top of TermMap. *)
structure TermSet = ElementSet (TermMap);