src/Provers/blast.ML
changeset 24062 845c0d693328
parent 23985 83e6e9ad0f4f
child 24099 6534fd4c5d46
--- a/src/Provers/blast.ML	Sun Jul 29 19:46:02 2007 +0200
+++ b/src/Provers/blast.ML	Sun Jul 29 19:46:03 2007 +0200
@@ -85,17 +85,13 @@
   val trace             : bool ref
   val fullTrace         : branch list list ref
   val fromType          : (indexname * term) list ref -> Term.typ -> term
-  val fromTerm          : Term.term -> term
-  val fromSubgoal       : Term.term -> term
+  val fromTerm          : theory -> Term.term -> term
+  val fromSubgoal       : theory -> Term.term -> term
   val instVars          : term -> (unit -> unit)
   val toTerm            : int -> term -> Term.term
   val readGoal          : theory -> string -> term
   val tryInThy          : theory -> int -> string ->
                   (int->tactic) list * branch list list * (int*int*exn) list
-  val trygl             : claset -> int -> int ->
-                  (int->tactic) list * branch list list * (int*int*exn) list
-  val Trygl             : int -> int ->
-                  (int->tactic) list * branch list list * (int*int*exn) list
   val normBr            : branch -> branch
   end;
 
@@ -117,6 +113,41 @@
   | Abs    of string*term
   | op $   of term*term;
 
+(*Pending formulae carry md (may duplicate) flags*)
+type branch =
+    {pairs: ((term*bool) list * (*safe formulae on this level*)
+               (term*bool) list) list,  (*haz formulae  on this level*)
+     lits:   term list,                 (*literals: irreducible formulae*)
+     vars:   term option ref list,      (*variables occurring in branch*)
+     lim:    int};                      (*resource limit*)
+
+
+(* global state information *)
+
+datatype state = State of
+ {thy: theory,
+  fullTrace: branch list list ref,
+  trail: term option ref list ref,
+  ntrail: int ref,
+  nclosed: int ref,
+  ntried: int ref}
+
+fun reject_const thy c =
+  is_some (Sign.const_type thy c) andalso
+    error ("blast: theory contains illegal constant " ^ quote c);
+
+fun initialize thy =
+ (reject_const thy "*Goal*";
+  reject_const thy "*False*";
+  State
+   {thy = thy,
+    fullTrace = ref [],
+    trail = ref [],
+    ntrail = ref 0,
+    nclosed = ref 0,  (*branches closed: number of branches closed during the search*)
+    ntried = ref 1}); (*branches tried: number of branches created by splitting (counting from 1)*)
+
+
 
 (** Basic syntactic operations **)
 
@@ -176,11 +207,8 @@
                            end
                  | SOME v => v)
 
-(*refer to the theory in which blast is initialized*)
-val typargs = ref (fn ((_, T): string * typ) => [T]);
-
-fun fromConst alist (a, T) =
-  Const (a, map (fromType alist) (! typargs (a, T)));
+fun fromConst thy alist (a, T) =
+  Const (a, map (fromType alist) (Sign.const_typargs thy (a, T)));
 
 
 (*Tests whether 2 terms are alpha-convertible; chases instantiations*)
@@ -324,12 +352,9 @@
 
 exception UNIFY;
 
-val trail = ref [] : term option ref list ref;
-val ntrail = ref 0;
-
 
 (*Restore the trail to some previous state: for backtracking*)
-fun clearTo n =
+fun clearTo (State {ntrail, trail, ...}) n =
     while !ntrail<>n do
         (hd(!trail) := NONE;
          trail := tl (!trail);
@@ -340,8 +365,9 @@
   "vars" is a list of variables local to the rule and NOT to be put
         on the trail (no point in doing so)
 *)
-fun unify(vars,t,u) =
-    let val n = !ntrail
+fun unify state (vars,t,u) =
+    let val State {ntrail, trail, ...} = state
+        val n = !ntrail
         fun update (t as Var v, u) =
             if t aconv u then ()
             else if varOccur v u then raise UNIFY
@@ -364,16 +390,16 @@
         and unifysAux ([], []) = ()
           | unifysAux (t :: ts, u :: us) = (unifyAux (t, u); unifysAux (ts, us))
           | unifysAux _ = raise UNIFY;
-    in  (unifyAux(t,u); true) handle UNIFY => (clearTo n; false)
+    in  (unifyAux(t,u); true) handle UNIFY => (clearTo state n; false)
     end;
 
 
 (*Convert from "real" terms to prototerms; eta-contract.
   Code is similar to fromSubgoal.*)
-fun fromTerm t =
+fun fromTerm thy t =
   let val alistVar = ref []
       and alistTVar = ref []
-      fun from (Term.Const aT) = fromConst alistTVar aT
+      fun from (Term.Const aT) = fromConst thy alistTVar aT
         | from (Term.Free  (a,_)) = Free a
         | from (Term.Bound i)     = Bound i
         | from (Term.Var (ixn,T)) =
@@ -479,8 +505,8 @@
 (*Tableau rule from elimination rule.
   Flag "upd" says that the inference updated the branch.
   Flag "dup" requests duplication of the affected formula.*)
-fun fromRule vars rl =
-  let val trl = rl |> Thm.prop_of |> fromTerm |> convertRule vars
+fun fromRule thy vars rl =
+  let val trl = rl |> Thm.prop_of |> fromTerm thy |> convertRule vars
       fun tac (upd, dup,rot) i =
         emtac upd (if dup then rev_dup_elim rl else rl) i
         THEN
@@ -511,8 +537,8 @@
   Flag "dup" requests duplication of the affected formula.
   Since haz rules are now delayed, "dup" is always FALSE for
   introduction rules.*)
-fun fromIntrRule vars rl =
-  let val trl = rl |> Thm.prop_of |> fromTerm |> convertIntrRule vars
+fun fromIntrRule thy vars rl =
+  let val trl = rl |> Thm.prop_of |> fromTerm thy |> convertIntrRule vars
       fun tac (upd,dup,rot) i =
          rmtac upd (if dup then Data.dup_intr rl else rl) i
          THEN
@@ -534,27 +560,17 @@
   | toTerm d (f $ u)       = Term.$ (toTerm d f, toTerm (d-1) u);
 
 
-fun netMkRules P vars (nps: netpair list) =
+fun netMkRules thy P vars (nps: netpair list) =
   case P of
       (Const ("*Goal*", _) $ G) =>
         let val pG = mk_Trueprop (toTerm 2 G)
             val intrs = maps (fn (inet,_) => Net.unify_term inet pG) nps
-        in  map (fromIntrRule vars o #2) (Tactic.orderlist intrs)  end
+        in  map (fromIntrRule thy vars o #2) (Tactic.orderlist intrs)  end
     | _ =>
         let val pP = mk_Trueprop (toTerm 3 P)
             val elims = maps (fn (_,enet) => Net.unify_term enet pP) nps
-        in  map_filter (fromRule vars o #2) (Tactic.orderlist elims)  end;
-
+        in  map_filter (fromRule thy vars o #2) (Tactic.orderlist elims)  end;
 
-(*Pending formulae carry md (may duplicate) flags*)
-type branch =
-    {pairs: ((term*bool) list * (*safe formulae on this level*)
-               (term*bool) list) list,  (*haz formulae  on this level*)
-     lits:   term list,                 (*literals: irreducible formulae*)
-     vars:   term option ref list,      (*variables occurring in branch*)
-     lim:    int};                      (*resource limit*)
-
-val fullTrace = ref[] : branch list list ref;
 
 (*Normalize a branch--for tracing*)
 fun norm2 (G,md) = (norm G, md);
@@ -598,7 +614,7 @@
   | showTerm d (f $ u)       = if d=0 then dummyVar
                                else Term.$ (showTerm d f, showTerm (d-1) u);
 
-fun string_of sign d t = Sign.string_of_term sign (showTerm d t);
+fun string_of thy d t = Sign.string_of_term thy (showTerm d t);
 
 (*Convert a Goal to an ordinary Not.  Used also in dup_intr, where a goal like
   Ex(P) is duplicated as the assumption ~Ex(P). *)
@@ -616,20 +632,20 @@
 fun negOfGoal_tac i = TRACE Data.ccontr (rtac Data.ccontr) i THEN
                       rotate_tac ~1 i;
 
-fun traceTerm sign t =
+fun traceTerm thy t =
   let val t' = norm (negOfGoal t)
-      val stm = string_of sign 8 t'
+      val stm = string_of thy 8 t'
   in
-      case topType sign t' of
+      case topType thy t' of
           NONE   => stm   (*no type to attach*)
-        | SOME T => stm ^ "\t:: " ^ Sign.string_of_typ sign T
+        | SOME T => stm ^ "\t:: " ^ Sign.string_of_typ thy T
   end;
 
 
 (*Print tracing information at each iteration of prover*)
-fun tracing sign brs =
-  let fun printPairs (((G,_)::_,_)::_)  = Output.immediate_output(traceTerm sign G)
-        | printPairs (([],(H,_)::_)::_) = Output.immediate_output(traceTerm sign H ^ "\t (Unsafe)")
+fun tracing (State {thy, fullTrace, ...}) brs =
+  let fun printPairs (((G,_)::_,_)::_)  = Output.immediate_output(traceTerm thy G)
+        | printPairs (([],(H,_)::_)::_) = Output.immediate_output(traceTerm thy H ^ "\t (Unsafe)")
         | printPairs _                 = ()
       fun printBrs (brs0 as {pairs, lits, lim, ...} :: brs) =
             (fullTrace := brs0 :: !fullTrace;
@@ -643,14 +659,14 @@
 fun traceMsg s = if !trace then writeln s else ();
 
 (*Tracing: variables updated in the last branch operation?*)
-fun traceVars sign ntrl =
+fun traceVars (State {thy, ntrail, trail, ...}) ntrl =
   if !trace then
       (case !ntrail-ntrl of
             0 => ()
           | 1 => Output.immediate_output"\t1 variable UPDATED:"
           | n => Output.immediate_output("\t" ^ Int.toString n ^ " variables UPDATED:");
        (*display the instantiations themselves, though no variable names*)
-       List.app (fn v => Output.immediate_output("   " ^ string_of sign 4 (the (!v))))
+       List.app (fn v => Output.immediate_output("   " ^ string_of thy 4 (the (!v))))
            (List.take(!trail, !ntrail-ntrl));
        writeln"")
     else ();
@@ -736,7 +752,7 @@
 (*Substitute through the branch if an equality goal (else raise DEST_EQ).
   Moves affected literals back into the branch, but it is not clear where
   they should go: this could make proofs fail.*)
-fun equalSubst sign (G, {pairs, lits, vars, lim}) =
+fun equalSubst thy (G, {pairs, lits, vars, lim}) =
   let val (t,u) = orientGoal(dest_eq G)
       val subst = subst_atomic (t,u)
       fun subst2(G,md) = (subst G, md)
@@ -759,8 +775,8 @@
             end
       val (changed, lits') = foldr subLit ([], []) lits
       val (changed', pairs') = foldr subFrame (changed, []) pairs
-  in  if !trace then writeln ("Substituting " ^ traceTerm sign u ^
-                              " for " ^ traceTerm sign t ^ " in branch" )
+  in  if !trace then writeln ("Substituting " ^ traceTerm thy u ^
+                              " for " ^ traceTerm thy t ^ " in branch" )
       else ();
       {pairs = (changed',[])::pairs',   (*affected formulas, and others*)
        lits  = lits',                   (*unaffected literals*)
@@ -781,9 +797,9 @@
 val eAssume_tac = TRACE asm_rl   (eq_assume_tac ORELSE' assume_tac);
 
 (*Try to unify complementary literals and return the corresponding tactic. *)
-fun tryClose (G, L) =
+fun tryClose state (G, L) =
   let
-    fun close t u tac = if unify ([], t, u) then SOME tac else NONE;
+    fun close t u tac = if unify state ([], t, u) then SOME tac else NONE;
     fun arg (_ $ t) = t;
   in
     if isGoal G then close (arg G) L eAssume_tac
@@ -820,9 +836,9 @@
 (*nbrs = # of branches just prior to closing this one.  Delete choice points
   for goals proved by the latest inference, provided NO variables in the
   next branch have been updated.*)
-fun prune (1, nxtVars, choices) = choices  (*DON'T prune at very end: allow
+fun prune _ (1, nxtVars, choices) = choices  (*DON'T prune at very end: allow
                                              backtracking over bad proofs*)
-  | prune (nbrs: int, nxtVars, choices) =
+  | prune (State {ntrail, trail, ...}) (nbrs: int, nxtVars, choices) =
       let fun traceIt last =
                 let val ll = length last
                     and lc = length choices
@@ -895,12 +911,7 @@
   | matchs (t :: ts) (u :: us) = match t u andalso matchs ts us;
 
 
-(*Branches closed: number of branches closed during the search
-  Branches tried:  number of branches created by splitting (counting from 1)*)
-val nclosed = ref 0
-and ntried  = ref 1;
-
-fun printStats (b, start, tacs) =
+fun printStats (State {ntried, nclosed, ...}) (b, start, tacs) =
   if b then
     writeln (end_timing start ^ " for search.  Closed: "
              ^ Int.toString (!nclosed) ^
@@ -914,12 +925,13 @@
   bound on unsafe expansions.
  "start" is CPU time at start, for printing search time
 *)
-fun prove (sign, start, cs, brs, cont) =
- let val {safe0_netpair, safep_netpair, haz_netpair, ...} = Data.rep_cs cs
+fun prove (state, start, cs, brs, cont) =
+ let val State {thy, ntrail, nclosed, ntried, ...} = state;
+     val {safe0_netpair, safep_netpair, haz_netpair, ...} = Data.rep_cs cs
      val safeList = [safe0_netpair, safep_netpair]
      and hazList  = [haz_netpair]
      fun prv (tacs, trs, choices, []) =
-                (printStats (!trace orelse !stats, start, tacs);
+                (printStats state (!trace orelse !stats, start, tacs);
                  cont (tacs, trs, choices))   (*all branches closed!*)
        | prv (tacs, trs, choices,
               brs0 as {pairs = ((G,md)::br, haz)::pairs,
@@ -931,7 +943,7 @@
               val nbrs = length brs0
               val nxtVars = nextVars brs
               val G = norm G
-              val rules = netMkRules G vars safeList
+              val rules = netMkRules thy G vars safeList
               (*Make a new branch, decrementing "lim" if instantiations occur*)
               fun newBr (vars',lim') prems =
                   map (fn prem =>
@@ -952,7 +964,7 @@
                 to branch.*)
               fun deeper [] = raise NEWBRANCHES
                 | deeper (((P,prems),tac)::grls) =
-                    if unify(add_term_vars(P,[]), P, G)
+                    if unify state (add_term_vars(P,[]), P, G)
                     then  (*P comes from the rule; G comes from the branch.*)
                      let val updated = ntrl < !ntrail (*branch updated*)
                          val lim' = if updated
@@ -964,16 +976,16 @@
                          val tacs' = (tac(updated,false,true))
                                      :: tacs  (*no duplication; rotate*)
                      in
-                         traceNew prems;  traceVars sign ntrl;
+                         traceNew prems;  traceVars state ntrl;
                          (if null prems then (*closed the branch: prune!*)
                             (nclosed := !nclosed + 1;
                              prv(tacs',  brs0::trs,
-                                 prune (nbrs, nxtVars, choices'),
+                                 prune state (nbrs, nxtVars, choices'),
                                  brs))
                           else (*prems non-null*)
                           if lim'<0 (*faster to kill ALL the alternatives*)
                           then (traceMsg"Excessive branching: KILLED";
-                                clearTo ntrl;  raise NEWBRANCHES)
+                                clearTo state ntrl;  raise NEWBRANCHES)
                           else
                             (ntried := !ntried + length prems - 1;
                              prv(tacs',  brs0::trs, choices',
@@ -982,7 +994,7 @@
                            if updated then
                                 (*Backtrack at this level.
                                   Reset Vars and try another rule*)
-                                (clearTo ntrl;  deeper grls)
+                                (clearTo state ntrl;  deeper grls)
                            else (*backtrack to previous level*)
                                 backtrack choices
                      end
@@ -990,21 +1002,21 @@
               (*Try to close branch by unifying with head goal*)
               fun closeF [] = raise CLOSEF
                 | closeF (L::Ls) =
-                    case tryClose(G,L) of
+                    case tryClose state (G,L) of
                         NONE     => closeF Ls
                       | SOME tac =>
                             let val choices' =
                                     (if !trace then (Output.immediate_output"branch closed";
-                                                     traceVars sign ntrl)
+                                                     traceVars state ntrl)
                                                else ();
-                                     prune (nbrs, nxtVars,
+                                     prune state (nbrs, nxtVars,
                                             (ntrl, nbrs, PRV) :: choices))
                             in  nclosed := !nclosed + 1;
                                 prv (tac::tacs, brs0::trs, choices', brs)
                                 handle PRV =>
                                     (*reset Vars and try another literal
                                       [this handler is pruned if possible!]*)
-                                 (clearTo ntrl;  closeF Ls)
+                                 (clearTo state ntrl;  closeF Ls)
                             end
               (*Try to unify a queued formula (safe or haz) with head goal*)
               fun closeFl [] = raise CLOSEF
@@ -1012,12 +1024,12 @@
                     closeF (map fst br)
                       handle CLOSEF => closeF (map fst haz)
                         handle CLOSEF => closeFl pairs
-          in tracing sign brs0;
+          in tracing state brs0;
              if lim<0 then (traceMsg "Limit reached.  "; backtrack choices)
              else
              prv (Data.hyp_subst_tac (!trace) :: tacs,
                   brs0::trs,  choices,
-                  equalSubst sign
+                  equalSubst thy
                     (G, {pairs = (br,haz)::pairs,
                          lits  = lits, vars  = vars, lim   = lim})
                     :: brs)
@@ -1025,7 +1037,7 @@
               handle CLOSEF =>   closeFl ((br,haz)::pairs)
                 handle CLOSEF => deeper rules
                   handle NEWBRANCHES =>
-                   (case netMkRules G vars hazList of
+                   (case netMkRules thy G vars hazList of
                        [] => (*there are no plausible haz rules*)
                              (traceMsg "moving formula to literals";
                               prv (tacs, brs0::trs, choices,
@@ -1059,7 +1071,7 @@
           let exception PRV (*backtrack to precisely this recursion!*)
               val H = norm H
               val ntrl = !ntrail
-              val rules = netMkRules H vars hazList
+              val rules = netMkRules thy H vars hazList
               (*new premises of haz rules may NOT be duplicated*)
               fun newPrem (vars,P,dup,lim') prem =
                   let val Gs' = map (fn Q => (Q,false)) prem
@@ -1082,7 +1094,7 @@
                 to branch.*)
               fun deeper [] = raise NEWBRANCHES
                 | deeper (((P,prems),tac)::grls) =
-                    if unify(add_term_vars(P,[]), P, H)
+                    if unify state (add_term_vars(P,[]), P, H)
                     then
                      let val updated = ntrl < !ntrail (*branch updated*)
                          val vars  = vars_in_vars vars
@@ -1122,14 +1134,14 @@
                        if lim'<0 andalso not (null prems)
                        then (*it's faster to kill ALL the alternatives*)
                            (traceMsg"Excessive branching: KILLED";
-                            clearTo ntrl;  raise NEWBRANCHES)
+                            clearTo state ntrl;  raise NEWBRANCHES)
                        else
                          traceNew prems;
                          if !trace andalso dup then Output.immediate_output" (duplicating)"
                                                  else ();
                          if !trace andalso recur then Output.immediate_output" (recursive)"
                                                  else ();
-                         traceVars sign ntrl;
+                         traceVars state ntrl;
                          if null prems then nclosed := !nclosed + 1
                          else ntried := !ntried + length prems - 1;
                          prv(tac' :: tacs,
@@ -1139,12 +1151,12 @@
                           handle PRV =>
                               if mayUndo
                               then (*reset Vars and try another rule*)
-                                   (clearTo ntrl;  deeper grls)
+                                   (clearTo state ntrl;  deeper grls)
                               else (*backtrack to previous level*)
                                    backtrack choices
                      end
                     else deeper grls
-          in tracing sign brs0;
+          in tracing state brs0;
              if lim<1 then (traceMsg "Limit reached.  "; backtrack choices)
              else deeper rules
              handle NEWBRANCHES =>
@@ -1185,7 +1197,7 @@
 exception TRANS of string;
 
 (*Translation of a subgoal: Skolemize all parameters*)
-fun fromSubgoal t =
+fun fromSubgoal thy t =
   let val alistVar = ref []
       and alistTVar = ref []
       fun hdvar ((ix,(v,is))::_) = v
@@ -1201,7 +1213,7 @@
                       "Function unknown's argument not a bound variable"
         in
           case ht of
-              Term.Const aT    => apply (fromConst alistTVar aT)
+              Term.Const aT    => apply (fromConst thy alistTVar aT)
             | Term.Free  (a,_) => apply (Free a)
             | Term.Bound i     => apply (Bound i)
             | Term.Var (ix,_) =>
@@ -1231,25 +1243,15 @@
   in  skoSubgoal 0 (from 0 (discard_foralls t))  end;
 
 
-fun reject_const thy c =
-  if is_some (Sign.const_type thy c) then
-    error ("Blast: theory contains illegal constant " ^ quote c)
-  else ();
-
-fun initialize thy =
- (fullTrace:=[];  trail := [];  ntrail := 0;
-  nclosed := 0;  ntried := 1;  typargs := Sign.const_typargs thy;
-  reject_const thy "*Goal*"; reject_const thy "*False*");
-
-
 (*Tactic using tableau engine and proof reconstruction.
  "start" is CPU time at start, for printing SEARCH time
         (also prints reconstruction time)
  "lim" is depth limit.*)
-fun timing_depth_tac start cs lim i st0 = NAMED_CRITICAL "blast" (fn () =>
-  let val st = (initialize (theory_of_thm st0); Conv.gconv_rule ObjectLogic.atomize_prems i st0);
-      val sign = Thm.theory_of_thm st
-      val skoprem = fromSubgoal (List.nth(prems_of st, i-1))
+fun timing_depth_tac start cs lim i st0 =
+  let val thy = Thm.theory_of_thm st0
+      val state = initialize thy
+      val st = Conv.gconv_rule ObjectLogic.atomize_prems i st0
+      val skoprem = fromSubgoal thy (List.nth(prems_of st, i-1))
       val hyps  = strip_imp_prems skoprem
       and concl = strip_imp_concl skoprem
       fun cont (tacs,_,choices) =
@@ -1266,10 +1268,8 @@
                        else ();
                        Seq.make(fn()=> cell))
           end
-      val forced = Seq.pull
-        (prove (sign, start, cs, [initBranch (mkGoal concl :: hyps, lim)], cont))
-  in Seq.make (fn () => forced) end
-  handle PROVE     => Seq.empty);
+  in prove (state, start, cs, [initBranch (mkGoal concl :: hyps, lim)], cont) end
+  handle PROVE     => Seq.empty
 
 (*Public version with fixed depth*)
 fun depth_tac cs lim i st = timing_depth_tac (start_timing ()) cs lim i st;
@@ -1289,31 +1289,18 @@
 (*** For debugging: these apply the prover to a subgoal and return
      the resulting tactics, trace, etc.                            ***)
 
-(*Translate subgoal i from a proof state*)
-fun trygl cs lim i =
-        let val st = topthm()
-                val sign = Thm.theory_of_thm st
-                val skoprem = (initialize (theory_of_thm st);
-                               fromSubgoal (List.nth(prems_of st, i-1)))
-                val hyps  = strip_imp_prems skoprem
-                and concl = strip_imp_concl skoprem
-        in timeap prove (sign, start_timing (), cs,
-                         [initBranch (mkGoal concl :: hyps, lim)], I)
-        end
-        handle Subscript => error("There is no subgoal " ^ Int.toString i);
-
-fun Trygl lim i = trygl (Data.claset()) lim i;
+val fullTrace = ref ([]: branch list list);
 
 (*Read a string to make an initial, singleton branch*)
-fun readGoal thy s = Sign.read_prop thy s |> fromTerm |> rand |> mkGoal;
+fun readGoal thy s = Sign.read_prop thy s |> fromTerm thy |> rand |> mkGoal;
 
 fun tryInThy thy lim s =
-    (initialize thy;
-     timeap prove (thy,
-                   start_timing(),
-                   Data.claset(),
-                   [initBranch ([readGoal thy s], lim)],
-                   I));
+  let
+    val state as State {fullTrace = ft, ...} = initialize thy;
+    val res = timeap prove
+      (state, start_timing(), Data.claset(), [initBranch ([readGoal thy s], lim)], I);
+    val _ = fullTrace := !ft;
+  in res end;
 
 
 (** method setup **)