src/Tools/Metis/src/Rewrite.sml
author wenzelm
Fri, 02 Oct 2009 22:15:08 +0200
changeset 32861 105f40051387
parent 23510 4521fead5609
child 39353 7f11d833d65b
permissions -rw-r--r--
eliminated dead code;

(* ========================================================================= *)
(* ORDERED REWRITING FOR FIRST ORDER TERMS                                   *)
(* Copyright (c) 2003-2006 Joe Hurd, distributed under the BSD License *)
(* ========================================================================= *)

structure Rewrite :> Rewrite =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* A type of rewrite systems.                                                *)
(* ------------------------------------------------------------------------- *)

(* The direction in which an equation is applied: LeftToRight rewrites     *)
(* instances of its left-hand side to its right-hand side and RightToLeft  *)
(* the reverse (see redexResidue).                                         *)
datatype orient = LeftToRight | RightToLeft;

(* A reduction order on terms. SOME GREATER on (s,t) is what licenses      *)
(* rewriting s to t (see orderToOrient and rewrIdConv'); NONE is treated   *)
(* as incomparable.                                                        *)
type reductionOrder = Term.term * Term.term -> order option;

(* Key under which an equation is stored in a rewrite system. *)
type equationId = int;

type equation = Rule.equation;

(* A rewrite system:                                                       *)
(*   order    - the reduction order used to orient equations;              *)
(*   known    - every equation with its orientation; NONE means it could   *)
(*              not be oriented, so each use is checked against order;     *)
(*   redexes  - index from redex pattern (the side to be rewritten) to     *)
(*              the equation id and the direction it applies in;           *)
(*   subterms - index of the subterms of equations, tagged with the        *)
(*              equation id, whether they lie in the left-hand side, and   *)
(*              their path;                                                *)
(*   waiting  - ids added by add that reduce has not yet processed.        *)
datatype rewrite =
    Rewrite of
      {order : reductionOrder,
       known : (equation * orient option) IntMap.map,
       redexes : (equationId * orient) TermNet.termNet,
       subterms : (equationId * bool * Term.path) TermNet.termNet,
       waiting : IntSet.set};

(* Replace the waiting set of a rewrite system, keeping every other *)
(* component as it was.                                             *)
fun updateWaiting (Rewrite {order, known, redexes, subterms, ...}) waiting' =
    Rewrite
      {order = order,
       known = known,
       redexes = redexes,
       subterms = subterms,
       waiting = waiting'};

(* Remove id from the waiting set of rw. *)
fun deleteWaiting rw id =
    let
      val Rewrite {waiting, ...} = rw
    in
      updateWaiting rw (IntSet.delete waiting id)
    end;

(* ------------------------------------------------------------------------- *)
(* Basic operations                                                          *)
(* ------------------------------------------------------------------------- *)

(* An empty rewrite system that will orient equations using order. *)
fun new order =
    Rewrite
      {waiting = IntSet.empty,
       subterms = TermNet.new {fifo = false},
       redexes = TermNet.new {fifo = false},
       known = IntMap.new (),
       order = order};

(* Look up the equation (with its orientation) stored under id. *)
fun peek rw id =
    let
      val Rewrite {known, ...} = rw
    in
      IntMap.peek known id
    end;

(* The number of equations in the system. *)
fun size (Rewrite {known = eqns, ...}) = IntMap.size eqns;

(* All equations in the system, in the order IntMap.foldr produces them. *)
fun equations (Rewrite {known, ...}) =
    let
      fun collect (_, (eqn, _), acc) = eqn :: acc
    in
      IntMap.foldr collect [] known
    end;

(* Pretty-print a rewrite system as the list of its equations. *)
fun pp p rw = Parser.ppList Rule.ppEquation p (equations rw);

(*DEBUG
local
  fun orientOptionToString ort =
      case ort of
        SOME LeftToRight => "-->"
      | SOME RightToLeft => "<--"
      | NONE => "<->";

  open Parser;

  fun ppEq p ((x_y,_),ort) =
      ppBinop (" " ^ orientOptionToString ort) Term.pp Term.pp p x_y;

  fun ppField f ppA p a =
      (beginBlock p Inconsistent 2;
       addString p (f ^ " =");
       addBreak p (1,0);
       ppA p a;
       endBlock p);

  val ppKnown =
      ppField "known" (ppMap IntMap.toList (ppList (ppPair ppInt ppEq)));

  val ppRedexes =
      ppField
        "redexes"
        (TermNet.pp
           (ppPair ppInt (ppMap (orientOptionToString o SOME) ppString)));

  val ppSubterms =
      ppField
        "subterms"
        (TermNet.pp
           (ppMap
              (fn (i,l,p) => (i, (if l then 0 else 1) :: p))
              (ppPair ppInt Term.ppPath)));

  val ppWaiting = ppField "waiting" (ppMap (IntSet.toList) (ppList ppInt));
in
  fun pp p (Rewrite {known,redexes,subterms,waiting,...}) =
      (beginBlock p Inconsistent 2;
       addString p "Rewrite";
       addBreak p (1,0);
       beginBlock p Inconsistent 1;
       addString p "{";
       ppKnown p known;
(*TRACE5
       addString p ",";
       addBreak p (1,0);
       ppRedexes p redexes;
       addString p ",";
       addBreak p (1,0);
       ppSubterms p subterms;
       addString p ",";
       addBreak p (1,0);
       ppWaiting p waiting;
*)
       endBlock p;
       addString p "}";
       endBlock p);
end;
*)

(* Render a rewrite system as a string via pp. *)
fun toString rw = Parser.toString pp rw;

(* ------------------------------------------------------------------------- *)
(* Debug functions.                                                          *)
(* ------------------------------------------------------------------------- *)

(* Debug helper: a predicate testing whether a term, or any subterm of it, *)
(* can be rewritten to a smaller term (SOME GREATER under order) by some   *)
(* equation of known other than id, respecting each orientation.           *)
fun termReducible order known id =
    let
      (* Does some instance of lhs = rhs rewrite tm to a smaller term? *)
      fun reducesWith ((lhs,rhs),_) tm =
          case total (Subst.match Subst.empty lhs) tm of
            SOME sub =>
            order (tm, Subst.subst (Subst.normalize sub) rhs) = SOME GREATER
          | NONE => false

      fun reducesByKnown tm (otherId,(eqn,ort)) =
          if otherId = id then false
          else
            (ort <> SOME RightToLeft andalso reducesWith eqn tm) orelse
            (ort <> SOME LeftToRight andalso reducesWith (Rule.symEqn eqn) tm)

      fun reducible tm =
          IntMap.exists (reducesByKnown tm) known orelse
          (case tm of
             Term.Var _ => false
           | Term.Fn (_,args) => List.exists reducible args)
    in
      reducible
    end;

(* Debug helpers lifting termReducible to a literal (any argument), a set *)
(* of literals (any member) and a theorem (any literal of its clause).    *)
fun literalReducible order known id lit =
    let
      val reducible = termReducible order known id
    in
      List.exists reducible (Literal.arguments lit)
    end;

fun literalsReducible order known id lits =
    LiteralSet.exists (fn lit => literalReducible order known id lit) lits;

fun thmReducible order known id th =
    let
      val cl = Thm.clause th
    in
      literalsReducible order known id cl
    end;

(* ------------------------------------------------------------------------- *)
(* Add equations into the system.                                            *)
(* ------------------------------------------------------------------------- *)

(* Turn the reduction order verdict on an equation into an orientation:  *)
(* GREATER gives LeftToRight, LESS gives RightToLeft and incomparable    *)
(* gives NONE. An EQUAL verdict means the equation is reflexive and is   *)
(* rejected with Error.                                                  *)
fun orderToOrient ord =
    case ord of
      NONE => NONE
    | SOME GREATER => SOME LeftToRight
    | SOME LESS => SOME RightToLeft
    | SOME EQUAL => raise Error "Rewrite.orient: reflexive";

(* Index the redex side(s) of equation id in the redexes net: the left   *)
(* side for LeftToRight, the right side for RightToLeft, and both (left  *)
(* first) when the equation is unoriented.                               *)
fun addRedexes id (((l,r),_),ort) redexes =
    let
      fun insertRedex dir tm net = TermNet.insert net (tm,(id,dir))
    in
      case ort of
        SOME LeftToRight => insertRedex LeftToRight l redexes
      | SOME RightToLeft => insertRedex RightToLeft r redexes
      | NONE =>
        insertRedex RightToLeft r (insertRedex LeftToRight l redexes)
    end;

(* Add equation eqn under key id, orienting it with the system order and *)
(* indexing its redexes; it is queued in waiting for a later reduce. An  *)
(* id that is already present leaves the system unchanged. Raises Error  *)
(* (via orderToOrient) if the order judges the equation reflexive.       *)
fun add (rw as Rewrite {known,...}) (id,eqn) =
    if IntMap.inDomain id known then rw
    else
      let
        val Rewrite {order,redexes,subterms,waiting,...} = rw
        val (l_r,_) = eqn
        val ort = orderToOrient (order l_r)
        val rw' =
            Rewrite
              {order = order,
               known = IntMap.insert known (id,(eqn,ort)),
               redexes = addRedexes id (eqn,ort) redexes,
               subterms = subterms,
               waiting = IntSet.add waiting id}
(*TRACE5
        val () = Parser.ppTrace pp "Rewrite.add: result" rw'
*)
      in
        rw'
      end;

(* Add each (id,equation) pair of the list to rw, in list order. *)
fun addList rw id_eqns = foldl (fn (id_eqn,acc) => add acc id_eqn) rw id_eqns;

(* ------------------------------------------------------------------------- *)
(* Rewriting (the order must be a refinement of the rewrite order).          *)
(* ------------------------------------------------------------------------- *)

(* The redex entries that TermNet.match returns for tm, sorted into *)
(* decreasing order of equation id.                                 *)
fun matchingRedexes redexes tm =
    let
      fun byDecreasingId ((i,_),(j,_)) = Int.compare (j,i)
    in
      sort byDecreasingId (TermNet.match redexes tm)
    end;

(* May an equation with orientation ort be used in direction dir? An *)
(* unoriented equation may be used either way.                       *)
fun wellOriented ort dir =
    case ort of
      NONE => true
    | SOME fixed => fixed = dir;

(* The (redex, residue) pair of an equation used in direction dir. *)
fun redexResidue dir (((l,r),_) : equation) =
    case dir of
      LeftToRight => (l,r)
    | RightToLeft => (r,l);

(* The equation as used in direction dir: unchanged for LeftToRight, *)
(* flipped with Rule.symEqn for RightToLeft.                         *)
fun orientedEquation dir eqn =
    case dir of
      LeftToRight => eqn
    | RightToLeft => Rule.symEqn eqn;

(* Rewrite tm once at the top level with the first suitable equation     *)
(* other than id. Candidates come from matchingRedexes. A candidate is   *)
(* rejected if it is equation id, is used against its orientation, does  *)
(* not match tm, or (when unoriented) does not make tm smaller under     *)
(* order. Raises Error if no candidate succeeds.                         *)
fun rewrIdConv' order known redexes id tm =
    let
      fun tryRedex (redexId,dir) =
          let
            val () = if redexId = id then raise Error "same theorem" else ()
            val (eqn,ort) = IntMap.get known redexId
            val () =
                if wellOriented ort dir then () else raise Error "orientation"
            val (lhs,rhs) = redexResidue dir eqn
            val sub = Subst.normalize (Subst.match Subst.empty lhs tm)
            val tm' = Subst.subst sub rhs
            (* Unoriented equations must be checked against order here. *)
            val () =
                if Option.isSome ort orelse order (tm,tm') = SOME GREATER
                then ()
                else raise Error "order"
            val (_,th) = orientedEquation dir eqn
          in
            (tm', Thm.subst sub th)
          end
    in
      case first (total tryRedex) (matchingRedexes redexes tm) of
        SOME res => res
      | NONE => raise Error "Rewrite.rewrIdConv: no matching rewrites"
    end;

(* Rewrite a term repeatedly, top-down, with every equation other than *)
(* id. An empty system yields Rule.allConv.                            *)
fun rewriteIdConv' order known redexes id =
    if not (IntMap.null known) then
      Rule.repeatTopDownConv (rewrIdConv' order known redexes id)
    else Rule.allConv;

(* Build a conversion from a disequality literal ~(l = r) that rewrites  *)
(* the larger side (under order) to the smaller one. The justifying      *)
(* theorem is Thm.assume lit when l is larger, and Rule.symmetry         *)
(* instantiated with x := l, y := r when r is larger. Raises Error when  *)
(* the sides are equal or incomparable under order.                      *)
fun mkNeqConv order lit =
    let
      val (l,r) = Literal.destNeq lit

      (* A conversion that rewrites exactly the term src to dst, justified *)
      (* by th; it raises Error msg on any other term.                     *)
      fun exactConv src dst th msg tm =
          if tm = src then (dst,th) else raise Error msg
    in
      case order (l,r) of
        SOME GREATER => exactConv l r (Thm.assume lit) "mkNeqConv: LR"
      | SOME LESS =>
        let
          val sub = Subst.fromList [("x",l),("y",r)]
        in
          exactConv r l (Thm.subst sub Rule.symmetry) "mkNeqConv: RL"
        end
      | SOME EQUAL => raise Error "irreflexive"
      | NONE => raise Error "incomparable"
    end;

(* Conversions derived from disequality literals, keyed by the literal *)
(* each one was built from (see mkNeqConv).                            *)
datatype neqConvs = NeqConvs of Rule.conv LiteralMap.map;

(* The set holding no disequality conversions. *)
val neqConvsEmpty = NeqConvs (LiteralMap.new ());

(* Test whether the set holds no conversions. *)
fun neqConvsNull neq =
    case neq of NeqConvs convs => LiteralMap.null convs;

(* Try to extend the set with a conversion built from lit by mkNeqConv.  *)
(* Returns NONE when mkNeqConv raises Error on lit; it does so when the  *)
(* sides are equal or incomparable under order, and Literal.destNeq      *)
(* presumably also fails for a literal that is not a disequality.        *)
(* Otherwise it returns the set with lit mapped to its conversion.       *)
fun neqConvsAdd order (NeqConvs m) lit =
    case total (mkNeqConv order) lit of
      NONE => NONE
    | SOME conv => SOME (NeqConvs (LiteralMap.insert m (lit,conv)));

(* Split a set of literals into the disequality conversions that can be *)
(* built from them and the set of literals for which none could be.     *)
fun mkNeqConvs order lits =
    let
      fun classify (lit,(convs,rest)) =
          case neqConvsAdd order convs lit of
            NONE => (convs, LiteralSet.add rest lit)
          | SOME convs' => (convs',rest)
    in
      LiteralSet.foldl classify (neqConvsEmpty,LiteralSet.empty) lits
    end;

(* Remove the conversion keyed by lit. *)
fun neqConvsDelete neq lit =
    let
      val NeqConvs convs = neq
    in
      NeqConvs (LiteralMap.delete convs lit)
    end;

(* Combine all conversions with Rule.firstConv, in the order that *)
(* LiteralMap.foldr yields them.                                  *)
fun neqConvsToConv (NeqConvs convs) =
    let
      fun gather (_,conv,acc) = conv :: acc
    in
      Rule.firstConv (LiteralMap.foldr gather [] convs)
    end;

(* Fold f over the literals (not the conversions) of the set. *)
fun neqConvsFoldl f init (NeqConvs convs) =
    let
      fun step (lit,_,acc) = f (lit,acc)
    in
      LiteralMap.foldl step init convs
    end;

(* A literule that rewrites every argument of a literal repeatedly,     *)
(* top-down. It tries the disequality conversions neq first, then the   *)
(* equations of the system other than id. When both sources are empty  *)
(* it is Rule.allLiterule.                                              *)
fun neqConvsRewrIdLiterule order known redexes id neq =
    if not (IntMap.null known) orelse not (neqConvsNull neq) then
      let
        val combined =
            Rule.orelseConv
              (neqConvsToConv neq)
              (rewrIdConv' order known redexes id)
      in
        Rule.allArgumentsLiterule (Rule.repeatTopDownConv combined)
      end
    else Rule.allLiterule;

(* Rewrite the equation eqn = (l_r,th) with every equation of the system  *)
(* except id, and with conversions built from the disequality literals of *)
(* the clause of th (via mkNeqConvs). Returns eqn itself when its         *)
(* equality literal is unchanged; otherwise returns the rewritten sides   *)
(* paired with a justifying theorem.                                      *)
fun rewriteIdEqn' order known redexes id (eqn as (l_r,th)) =
    let
      (* Only the disequality conversions are used; the leftover literals *)
      (* returned by mkNeqConvs are ignored here.                         *)
      val (neq,_) = mkNeqConvs order (Thm.clause th)
      val literule = neqConvsRewrIdLiterule order known redexes id neq
      (* strongEqn is true when Rule.equationLiteral returns NONE, in which *)
      (* case the literal to rewrite is synthesised with Literal.mkEq.      *)
      (* NOTE(review): presumably NONE means the equality is not a literal  *)
      (* of the clause of th; confirm against Rule.equationLiteral.         *)
      val (strongEqn,lit) =
          case Rule.equationLiteral eqn of
            NONE => (true, Literal.mkEq l_r)
          | SOME lit => (false,lit)
      val (lit',litTh) = literule lit
    in
      if lit = lit' then eqn
      else
        (Literal.destEq lit',
         (* Theorem for the rewritten equation: the original th when       *)
         (* strongEqn; litTh alone when it does not contain the negation   *)
         (* of lit; otherwise th resolved with litTh on lit.               *)
         if strongEqn then th
         else if not (Thm.negateMember lit litTh) then litTh
         else Thm.resolve lit th litTh)
    end
(*DEBUG
    handle Error err => raise Error ("Rewrite.rewriteIdEqn':\n" ^ err);
*)

(* Rewrite the literals lits of theorem th with every equation of the      *)
(* system except id. The members of lits that mkNeqConvs can turn into     *)
(* disequality conversions are first rewritten by one another (and by the  *)
(* system) until a pass creates no new conversion. The remaining literals  *)
(* are then rewritten with the final conversions and the system, skipping  *)
(* any that are no longer in the theorem. Returns the resulting theorem.   *)
fun rewriteIdLiteralsRule' order known redexes id lits th =
    let
      val mk_literule = neqConvsRewrIdLiterule order known redexes id

      (* Rewrite one disequality literal with all the other conversions. *)
      (* Its own conversion is deleted first so it cannot rewrite itself. *)
      (* If the literal is unchanged, the accumulator is returned as is,  *)
      (* including the conversion for lit.                                *)
      fun rewr_neq_lit (lit, acc as (changed,neq,lits,th)) =
          let
            val neq = neqConvsDelete neq lit
            val (lit',litTh) = mk_literule neq lit
          in
            if lit = lit' then acc
            else
              let
                val th = Thm.resolve lit th litTh
              in
                (* If the rewritten literal still yields a conversion, it *)
                (* replaces the old one and another pass is requested.   *)
                (* Otherwise it joins the ordinary literals.             *)
                case neqConvsAdd order neq lit' of
                  SOME neq => (true,neq,lits,th)
                | NONE => (changed, neq, LiteralSet.add lits lit', th)
              end
          end

      (* Repeat passes of rewr_neq_lit over the literals of neq until a *)
      (* pass adds no new conversion.                                   *)
      fun rewr_neq_lits neq lits th =
          let
            val (changed,neq,lits,th) =
                neqConvsFoldl rewr_neq_lit (false,neq,lits,th) neq
          in
            if changed then rewr_neq_lits neq lits th
            else (neq,lits,th)
          end

      val (neq,lits) = mkNeqConvs order lits

      val (neq,lits,th) = rewr_neq_lits neq lits th

      val rewr_literule = mk_literule neq

      (* Only rewrite literals that are still present in the theorem. *)
      fun rewr_lit (lit,th) =
          if Thm.member lit th then Rule.literalRule rewr_literule lit th
          else th
    in
      LiteralSet.foldl rewr_lit th lits
    end;

(* Rewrite every literal of the clause of th, skipping equation id. *)
fun rewriteIdRule' order known redexes id th =
    let
      val lits = Thm.clause th
    in
      rewriteIdLiteralsRule' order known redexes id lits th
    end;

(*DEBUG
val rewriteIdRule' = fn order => fn known => fn redexes => fn id => fn th =>
    let
(*TRACE6
      val () = Parser.ppTrace Thm.pp "Rewrite.rewriteIdRule': th" th
*)
      val result = rewriteIdRule' order known redexes id th
(*TRACE6
      val () = Parser.ppTrace Thm.pp "Rewrite.rewriteIdRule': result" result
*)
      val _ = not (thmReducible order known id result) orelse
              raise Bug "rewriteIdRule: should be normalized"
    in
      result
    end
    handle Error err => raise Error ("Rewrite.rewriteIdRule:\n" ^ err);
*)

(* Public wrappers. Each unpacks the rewrite system and calls the primed  *)
(* worker with the given order. The Id variants skip the equation with    *)
(* the supplied id. The plain variants pass ~1, presumably an id that is  *)
(* never used for a real equation (confirm with callers of add).          *)
fun rewrIdConv rw order =
    let
      val Rewrite {known,redexes,...} = rw
    in
      rewrIdConv' order known redexes
    end;

fun rewrConv rw order = rewrIdConv rw order ~1;

fun rewriteIdConv rw order =
    let
      val Rewrite {known,redexes,...} = rw
    in
      rewriteIdConv' order known redexes
    end;

fun rewriteConv rw order = rewriteIdConv rw order ~1;

fun rewriteIdLiteralsRule rw order =
    let
      val Rewrite {known,redexes,...} = rw
    in
      rewriteIdLiteralsRule' order known redexes
    end;

fun rewriteLiteralsRule rw order = rewriteIdLiteralsRule rw order ~1;

fun rewriteIdRule rw order =
    let
      val Rewrite {known,redexes,...} = rw
    in
      rewriteIdRule' order known redexes
    end;

fun rewriteRule rw order = rewriteIdRule rw order ~1;

(* ------------------------------------------------------------------------- *)
(* Inter-reduce the equations in the system.                                 *)
(* ------------------------------------------------------------------------- *)

(* Index every subterm of both sides of equation id in the subterms net. *)
(* Left-hand subterms are tagged true, right-hand ones false, and each   *)
(* carries its path. The left side is inserted first.                    *)
fun addSubterms id (((l,r),_) : equation) subterms =
    let
      fun indexSide onLeft side net =
          foldl (fn ((path,tm),acc) => TermNet.insert acc (tm,(id,onLeft,path)))
            net (Term.subterms side)
    in
      indexSide false r (indexSide true l subterms)
    end;

(* Is the redex side of an equation with orientation ort unchanged      *)
(* between the old sides (l0,r0) and the new sides (l,r)? Unoriented     *)
(* equations always report false.                                        *)
fun sameRedexes ort (l0,r0) (l,r) =
    case ort of
      NONE => false
    | SOME LeftToRight => l0 = l
    | SOME RightToLeft => r0 = r;

(* The (redex, residue, oriented) triples an equation can rewrite with: *)
(* one for an oriented equation, both directions (flagged unoriented)   *)
(* otherwise.                                                            *)
fun redexResidues ort (l,r) =
    case ort of
      SOME LeftToRight => [(l,r,true)]
    | SOME RightToLeft => [(r,l,true)]
    | NONE => [(l,r,false),(r,l,false)];

(* Returns a function that takes todo and a list of residues of equation *)
(* id and adds to todo every other equation with a subterm that one of   *)
(* the residues rewrites. Oriented residues only need to match; the      *)
(* others must also make the subterm smaller under order. Equations      *)
(* already in todo are not re-checked.                                   *)
fun findReducibles order known subterms id =
    let
      (* Returns () if residue (l,r,oriented) rewrites the subterm at path *)
      (* on the given side of equation otherId; raises otherwise.          *)
      fun checkRewrite (l,r,oriented) otherId onLeft path =
          let
            val (((lhs,rhs),_),_) = IntMap.get known otherId
            val target = Term.subterm (if onLeft then lhs else rhs) path
            val sub = Subst.match Subst.empty l target
          in
            if oriented then ()
            else if order (target, Subst.subst (Subst.normalize sub) r) =
                    SOME GREATER then ()
            else raise Error "order"
          end

      fun noteCandidate residue ((otherId,onLeft,path),todo) =
          if otherId = id orelse IntSet.member otherId todo then todo
          else if can (checkRewrite residue otherId onLeft) path then
            IntSet.add todo otherId
          else todo

      fun scanResidue (residue as (l,_,_), todo) =
          List.foldl (noteCandidate residue) todo (TermNet.matched subterms l)
    in
      List.foldl scanResidue
    end;

(* Inter-reduce equation id, whose current entry in known is (eqn0,ort0),  *)
(* against the other equations of the system. new is true when id has just *)
(* been taken from the waiting set (see reduceAcc). The accumulator is:    *)
(*   rpl     - ids whose redex index entries must be rebuilt;              *)
(*   spl     - ids whose subterm index entries must be rebuilt;            *)
(*   todo    - ids of equations that may now be reducible;                 *)
(*   rw      - the rewrite system being updated;                           *)
(*   changed - ids of equations that were new or were rewritten.           *)
(* The indexes are only extended here; rebuild removes stale entries.      *)
fun reduce1 new id (eqn0,ort0) (rpl,spl,todo,rw,changed) =
    let
      val (eq0,_) = eqn0
      val Rewrite {order,known,redexes,subterms,waiting} = rw
      (* Rewrite the equation with every other equation in the system. *)
      val eqn as (eq,_) = rewriteIdEqn' order known redexes id eqn0
      val identical = eq = eq0
      (* If the redex side is unchanged, the existing redex entries stay *)
      (* valid; otherwise id is marked for a redex rebuild.              *)
      val same_redexes = identical orelse sameRedexes ort0 eq0 eq
      val rpl = if same_redexes then rpl else IntSet.add rpl id
      (* An already-processed equation that changed has stale subterms. *)
      val spl = if new orelse identical then spl else IntSet.add spl id
      val changed =
          if not new andalso identical then changed else IntSet.add changed id
      (* Keep the old orientation when the redexes are unchanged, else    *)
      (* re-orient. orderToOrient raises only when order returns EQUAL,    *)
      (* so NONE here means the rewritten equation is treated as reflexive.*)
      val ort =
          if same_redexes then SOME ort0 else total orderToOrient (order eq)
    in
      case ort of
        NONE =>
        (* Drop the equation from known. This branch implies same_redexes *)
        (* is false, so id is in rpl and rebuild discards its redexes.    *)
        let
          val known = IntMap.delete known id
          val rw =
              Rewrite
                {order = order, known = known, redexes = redexes,
                 subterms = subterms, waiting = waiting}
        in
          (rpl,spl,todo,rw,changed)
        end
      | SOME ort =>
        let
          (* A new equation, or one whose redexes changed, may rewrite    *)
          (* others: queue every equation its residues can reduce.        *)
          val todo =
              if not new andalso same_redexes then todo
              else
                findReducibles
                  order known subterms id todo (redexResidues ort eq)
          val known =
              if identical then known else IntMap.insert known (id,(eqn,ort))
          (* New entries are added alongside any stale ones for id; the    *)
          (* stale ones are filtered out by rebuild via rpl and spl.       *)
          val redexes =
              if same_redexes then redexes
              else addRedexes id (eqn,ort) redexes
          val subterms =
              if new orelse not identical then addSubterms id eqn subterms
              else subterms
          val rw =
              Rewrite
                {order = order, known = known, redexes = redexes,
                 subterms = subterms, waiting = waiting}
        in
          (rpl,spl,todo,rw,changed)
        end
    end;

(* Choose the first id in set (by IntSet.firstl) that is still in known, *)
(* preferring one whose orientation is definite. Returns the id with its *)
(* (eqn,ort) entry, or NONE if no id of set is in known.                 *)
fun pick known set =
    let
      fun entry requireOriented id =
          case IntMap.peek known id of
            NONE => NONE
          | SOME (x as (_,ort)) =>
            if requireOriented andalso not (Option.isSome ort) then NONE
            else SOME (id,x)
    in
      case IntSet.firstl (entry true) set of
        NONE => IntSet.firstl (entry false) set
      | found => found
    end;

local
  (* Drop every redex entry of an id in rpl, then re-insert the redexes of *)
  (* those ids that are still in known.                                    *)
  fun cleanRedexes known redexes rpl =
      if IntSet.null rpl then redexes
      else
        let
          fun keep (id,_) = not (IntSet.member id rpl)

          fun reinsert (id,net) =
              case IntMap.peek known id of
                SOME eqn_ort => addRedexes id eqn_ort net
              | NONE => net
        in
          IntSet.foldl reinsert (TermNet.filter keep redexes) rpl
        end;

  (* Drop every subterm entry of an id in spl, then re-insert the subterms *)
  (* of those ids that are still in known.                                 *)
  fun cleanSubterms known subterms spl =
      if IntSet.null spl then subterms
      else
        let
          fun keep (id,_,_) = not (IntSet.member id spl)

          fun reinsert (id,net) =
              case IntMap.peek known id of
                SOME (eqn,_) => addSubterms id eqn net
              | NONE => net
        in
          IntSet.foldl reinsert (TermNet.filter keep subterms) spl
        end;
in
  (* Bring the redex and subterm indexes back in line with known for the *)
  (* ids collected in rpl and spl during inter-reduction.                *)
  fun rebuild rpl spl rw =
      let
(*TRACE5
        val ppPl = Parser.ppMap IntSet.toList (Parser.ppList Parser.ppInt)
        val () = Parser.ppTrace ppPl "Rewrite.rebuild: rpl" rpl
        val () = Parser.ppTrace ppPl "Rewrite.rebuild: spl" spl
*)
        val Rewrite {order,known,redexes,subterms,waiting} = rw
      in
        Rewrite
          {order = order,
           known = known,
           redexes = cleanRedexes known redexes rpl,
           subterms = cleanSubterms known subterms spl,
           waiting = waiting}
      end;
end;

(* Drive inter-reduction. Drain todo first, then take waiting equations  *)
(* one at a time. When both are exhausted, rebuild the indexes and return *)
(* the system with the list of changed ids.                               *)
fun reduceAcc (rpl, spl, todo, rw as Rewrite {known,waiting,...}, changed) =
    case pick known todo of
      SOME (id,eqn_ort) =>
      reduceAcc
        (reduce1 false id eqn_ort (rpl, spl, IntSet.delete todo id, rw, changed))
    | NONE =>
      (case pick known waiting of
         NONE => (rebuild rpl spl rw, IntSet.toList changed)
       | SOME (id,eqn_ort) =>
         reduceAcc
           (reduce1 true id eqn_ort
              (rpl, spl, todo, deleteWaiting rw id, changed)));

(* A system is fully inter-reduced when no equation is waiting. *)
fun isReduced rw =
    let
      val Rewrite {waiting, ...} = rw
    in
      IntSet.null waiting
    end;

(* Inter-reduce the system. Returns it together with the ids of the     *)
(* equations that were newly processed or rewritten (see reduceAcc).    *)
fun reduce' rw =
    if not (isReduced rw) then
      reduceAcc (IntSet.empty, IntSet.empty, IntSet.empty, rw, IntSet.empty)
    else (rw,[]);

(*DEBUG
val reduce' = fn rw =>
    let
(*TRACE4
      val () = Parser.ppTrace pp "Rewrite.reduce': rw" rw
*)
      val Rewrite {known,order,...} = rw
      val result as (Rewrite {known = known', ...}, _) = reduce' rw
(*TRACE4
      val ppResult = Parser.ppPair pp (Parser.ppList Parser.ppInt)
      val () = Parser.ppTrace ppResult "Rewrite.reduce': result" result
*)
      val ths = map (fn (id,((_,th),_)) => (id,th)) (IntMap.toList known')
      val _ =
          not (List.exists (uncurry (thmReducible order known')) ths) orelse
          raise Bug "Rewrite.reduce': not fully reduced"
    in
      result
    end
    handle Error err => raise Bug ("Rewrite.reduce': shouldn't fail\n" ^ err);
*)

(* Inter-reduce the system, discarding the list of changed ids. *)
fun reduce rw =
    let
      val (reduced,_) = reduce' rw
    in
      reduced
    end;

(* ------------------------------------------------------------------------- *)
(* Rewriting as a derived rule.                                              *)
(* ------------------------------------------------------------------------- *)

(* Build a rewrite system under order from the equations ths, numbered *)
(* with enumerate. Returns the rule that rewrites theorems with it.    *)
fun orderedRewrite order ths =
    rewriteRule (addList (new order) (enumerate ths)) order;

(* Ordered rewriting with an order that answers GREATER for every pair, *)
(* so every equation is oriented left to right.                         *)
fun rewrite eqns = orderedRewrite (fn _ => SOME GREATER) eqns;

end