src/Pure/search.ML
changeset 62916 621afc4607ec
parent 60940 4c108cce6b35
child 62919 9eb0359d6a77
equal deleted inserted replaced
62915:0f794993485a 62916:621afc4607ec
     7 
     7 
     8 infix 1 THEN_MAYBE THEN_MAYBE';
     8 infix 1 THEN_MAYBE THEN_MAYBE';
     9 
     9 
    10 signature SEARCH =
    10 signature SEARCH =
    11 sig
    11 sig
    12   val trace_DEPTH_FIRST: bool Unsynchronized.ref
       
    13   val DEPTH_FIRST: (thm -> bool) -> tactic -> tactic
    12   val DEPTH_FIRST: (thm -> bool) -> tactic -> tactic
    14   val has_fewer_prems: int -> thm -> bool
    13   val has_fewer_prems: int -> thm -> bool
    15   val IF_UNSOLVED: tactic -> tactic
    14   val IF_UNSOLVED: tactic -> tactic
    16   val SOLVE: tactic -> tactic
    15   val SOLVE: tactic -> tactic
    17   val THEN_MAYBE: tactic * tactic -> tactic
    16   val THEN_MAYBE: tactic * tactic -> tactic
    18   val THEN_MAYBE': ('a -> tactic) * ('a -> tactic) -> 'a -> tactic
    17   val THEN_MAYBE': ('a -> tactic) * ('a -> tactic) -> 'a -> tactic
    19   val DEPTH_SOLVE: tactic -> tactic
    18   val DEPTH_SOLVE: tactic -> tactic
    20   val DEPTH_SOLVE_1: tactic -> tactic
    19   val DEPTH_SOLVE_1: tactic -> tactic
    21   val THEN_ITER_DEEPEN: int -> tactic -> (thm -> bool) -> (int -> tactic) -> tactic
    20   val THEN_ITER_DEEPEN: int -> tactic -> (thm -> bool) -> (int -> tactic) -> tactic
    22   val ITER_DEEPEN: int -> (thm -> bool) -> (int -> tactic) -> tactic
    21   val ITER_DEEPEN: int -> (thm -> bool) -> (int -> tactic) -> tactic
    23   val trace_DEEPEN: bool Unsynchronized.ref
       
    24   val DEEPEN: int * int -> (int -> int -> tactic) -> int -> int -> tactic
    22   val DEEPEN: int * int -> (int -> int -> tactic) -> int -> int -> tactic
    25   val trace_BEST_FIRST: bool Unsynchronized.ref
       
    26   val THEN_BEST_FIRST: tactic -> (thm -> bool) * (thm -> int) -> tactic -> tactic
    23   val THEN_BEST_FIRST: tactic -> (thm -> bool) * (thm -> int) -> tactic -> tactic
    27   val BEST_FIRST: (thm -> bool) * (thm -> int) -> tactic -> tactic
    24   val BEST_FIRST: (thm -> bool) * (thm -> int) -> tactic -> tactic
    28   val BREADTH_FIRST: (thm -> bool) -> tactic -> tactic
    25   val BREADTH_FIRST: (thm -> bool) -> tactic -> tactic
    29   val QUIET_BREADTH_FIRST: (thm -> bool) -> tactic -> tactic
    26   val QUIET_BREADTH_FIRST: (thm -> bool) -> tactic -> tactic
    30   val trace_ASTAR: bool Unsynchronized.ref
       
    31   val THEN_ASTAR: tactic -> (thm -> bool) * (int -> thm -> int) -> tactic -> tactic
    27   val THEN_ASTAR: tactic -> (thm -> bool) * (int -> thm -> int) -> tactic -> tactic
    32   val ASTAR: (thm -> bool) * (int -> thm -> int) -> tactic -> tactic
    28   val ASTAR: (thm -> bool) * (int -> thm -> int) -> tactic -> tactic
    33 end;
    29 end;
    34 
    30 
    35 structure Search: SEARCH =
    31 structure Search: SEARCH =
    36 struct
    32 struct
    37 
    33 
    38 (**** Depth-first search ****)
    34 (**** Depth-first search ****)
    39 
    35 
    40 val trace_DEPTH_FIRST = Unsynchronized.ref false;
       
    41 
       
    42 (*Searches until "satp" reports proof tree as satisfied.
    36 (*Searches until "satp" reports proof tree as satisfied.
    43   Suppresses duplicate solutions to minimize search space.*)
    37   Suppresses duplicate solutions to minimize search space.*)
    44 fun DEPTH_FIRST satp tac =
    38 fun DEPTH_FIRST satp tac =
    45   let
    39   let
    46     val tac = tracify trace_DEPTH_FIRST tac
       
    47     fun depth used [] = NONE
    40     fun depth used [] = NONE
    48       | depth used (q :: qs) =
    41       | depth used (q :: qs) =
    49           (case Seq.pull q of
    42           (case Seq.pull q of
    50             NONE => depth used qs
    43             NONE => depth used qs
    51           | SOME (st, stq) =>
    44           | SOME (st, stq) =>
    52               if satp st andalso not (member Thm.eq_thm used st) then
    45               if satp st andalso not (member Thm.eq_thm used st) then
    53                 SOME (st, Seq.make (fn() => depth (st :: used) (stq :: qs)))
    46                 SOME (st, Seq.make (fn() => depth (st :: used) (stq :: qs)))
    54               else depth used (tac st :: stq :: qs));
    47               else depth used (tac st :: stq :: qs));
    55   in traced_tac (fn st => depth [] [Seq.single st]) end;
    48   in fn st => Seq.make (fn () => depth [] [Seq.single st]) end;
    56 
    49 
    57 
    50 
    58 (*Predicate: Does the rule have fewer than n premises?*)
    51 (*Predicate: Does the rule have fewer than n premises?*)
    59 fun has_fewer_prems n rule = Thm.nprems_of rule < n;
    52 fun has_fewer_prems n rule = Thm.nprems_of rule < n;
    60 
    53 
    97 fun prune (new as (k', np':int, rgd', stq), qs) =
    90 fun prune (new as (k', np':int, rgd', stq), qs) =
    98   let
    91   let
    99     fun prune_aux (qs, []) = new :: qs
    92     fun prune_aux (qs, []) = new :: qs
   100       | prune_aux (qs, (k, np, rgd, q) :: rqs) =
    93       | prune_aux (qs, (k, np, rgd, q) :: rqs) =
   101           if np' + 1 = np andalso rgd then
    94           if np' + 1 = np andalso rgd then
   102             (if !trace_DEPTH_FIRST then
       
   103                  tracing ("Pruning " ^
       
   104                           string_of_int (1+length rqs) ^ " levels")
       
   105              else ();
       
   106              (*Use OLD k: zero-cost solution; see Stickel, p 365*)
    95              (*Use OLD k: zero-cost solution; see Stickel, p 365*)
   107              (k, np', rgd', stq) :: qs)
    96              (k, np', rgd', stq) :: qs
   108           else prune_aux ((k, np, rgd, q) :: qs, rqs)
    97           else prune_aux ((k, np, rgd, q) :: qs, rqs)
   109       fun take ([], rqs) = ([], rqs)
    98       fun take ([], rqs) = ([], rqs)
   110         | take (arg as ((k, np, rgd, stq) :: qs, rqs)) =
    99         | take (arg as ((k, np, rgd, stq) :: qs, rqs)) =
   111             if np' < np then take (qs, (k, np, rgd, stq) :: rqs) else arg;
   100             if np' < np then take (qs, (k, np, rgd, stq) :: rqs) else arg;
   112   in prune_aux (take (qs, [])) end;
   101   in prune_aux (take (qs, [])) end;
   115 (*Depth-first iterative deepening search for a state that satisfies satp
   104 (*Depth-first iterative deepening search for a state that satisfies satp
   116   tactic tac0 sets up the initial goal queue, while tac1 searches it.
   105   tactic tac0 sets up the initial goal queue, while tac1 searches it.
   117   The solution sequence is redundant: the cutoff heuristic makes it impossible
   106   The solution sequence is redundant: the cutoff heuristic makes it impossible
   118   to suppress solutions arising from earlier searches, as the accumulated cost
   107   to suppress solutions arising from earlier searches, as the accumulated cost
   119   (k) can be wrong.*)
   108   (k) can be wrong.*)
   120 fun THEN_ITER_DEEPEN lim tac0 satp tac1 = traced_tac (fn st =>
   109 fun THEN_ITER_DEEPEN lim tac0 satp tac1 st =
   121   let
   110   let
   122     val countr = Unsynchronized.ref 0
   111     val countr = Unsynchronized.ref 0
   123     and tf = tracify trace_DEPTH_FIRST (tac1 1)
   112     and tf = tac1 1
   124     and qs0 = tac0 st
   113     and qs0 = tac0 st
   125      (*bnd = depth bound; inc = estimate of increment required next*)
   114      (*bnd = depth bound; inc = estimate of increment required next*)
   126     fun depth (bnd, inc) [] =
   115     fun depth (bnd, inc) [] =
   127           if bnd > lim then
   116           if bnd > lim then NONE
   128            (if !trace_DEPTH_FIRST then
       
   129              tracing (string_of_int (! countr) ^
       
   130               " inferences so far.  Giving up at " ^ string_of_int bnd)
       
   131             else ();
       
   132             NONE)
       
   133           else
   117           else
   134            (if !trace_DEPTH_FIRST then
       
   135               tracing (string_of_int (!countr) ^
       
   136                 " inferences so far.  Searching to depth " ^ string_of_int bnd)
       
   137             else ();
       
   138             (*larger increments make it run slower for the hard problems*)
   118             (*larger increments make it run slower for the hard problems*)
   139             depth (bnd + inc, 10)) [(0, 1, false, qs0)]
   119             depth (bnd + inc, 10) [(0, 1, false, qs0)]
   140       | depth (bnd, inc) ((k, np, rgd, q) :: qs) =
   120       | depth (bnd, inc) ((k, np, rgd, q) :: qs) =
   141           if k >= bnd then depth (bnd, inc) qs
   121           if k >= bnd then depth (bnd, inc) qs
   142           else
   122           else
   143            (case
   123            (case (Unsynchronized.inc countr; Seq.pull q) of
   144              (Unsynchronized.inc countr;
       
   145               if !trace_DEPTH_FIRST then
       
   146                 tracing (string_of_int np ^ implode (map (fn _ => "*") qs))
       
   147               else ();
       
   148               Seq.pull q) of
       
   149              NONE => depth (bnd, inc) qs
   124              NONE => depth (bnd, inc) qs
   150            | SOME (st, stq) =>
   125            | SOME (st, stq) =>
   151               if satp st then (*solution!*)
   126               if satp st then (*solution!*)
   152                 SOME(st, Seq.make (fn() => depth (bnd, inc) ((k, np, rgd, stq) :: qs)))
   127                 SOME(st, Seq.make (fn() => depth (bnd, inc) ((k, np, rgd, stq) :: qs)))
   153               else
   128               else
   160                   if k' + np' >= bnd then depth (bnd, Int.min (inc, k' + np' + 1 - bnd)) qs
   135                   if k' + np' >= bnd then depth (bnd, Int.min (inc, k' + np' + 1 - bnd)) qs
   161                   else if np' < np (*solved a subgoal; prune rigid ancestors*)
   136                   else if np' < np (*solved a subgoal; prune rigid ancestors*)
   162                   then depth (bnd, inc) (prune ((k', np', rgd', tf st), (k, np, rgd, stq) :: qs))
   137                   then depth (bnd, inc) (prune ((k', np', rgd', tf st), (k, np, rgd, stq) :: qs))
   163                   else depth (bnd, inc) ((k', np', rgd', tf st) :: (k, np, rgd, stq) :: qs)
   138                   else depth (bnd, inc) ((k', np', rgd', tf st) :: (k, np, rgd, stq) :: qs)
   164                 end)
   139                 end)
   165   in depth (0, 5) [] end);
   140   in Seq.make (fn () => depth (0, 5) []) end;
   166 
   141 
   167 fun ITER_DEEPEN lim = THEN_ITER_DEEPEN lim all_tac;
   142 fun ITER_DEEPEN lim = THEN_ITER_DEEPEN lim all_tac;
   168 
   143 
   169 
   144 
   170 (*Simple iterative deepening tactical.  It merely "deepens" any search tactic
   145 (*Simple iterative deepening tactical.  It merely "deepens" any search tactic
   171   using increment "inc" up to limit "lim". *)
   146   using increment "inc" up to limit "lim". *)
   172 val trace_DEEPEN = Unsynchronized.ref false;
       
   173 
       
   174 fun DEEPEN (inc, lim) tacf m i =
   147 fun DEEPEN (inc, lim) tacf m i =
   175   let
   148   let
   176     fun dpn m st =
   149     fun dpn m st =
   177       st |>
   150       st |>
   178        (if has_fewer_prems i st then no_tac
   151        (if has_fewer_prems i st then no_tac
   179         else if m > lim then
   152         else if m > lim then no_tac
   180           (if !trace_DEEPEN then tracing "Search depth limit exceeded: giving up" else ();
   153         else tacf m i  ORELSE  dpn (m+inc))
   181             no_tac)
       
   182         else
       
   183           (if !trace_DEEPEN then tracing ("Search depth = " ^ string_of_int m) else ();
       
   184             tacf m i  ORELSE  dpn (m+inc)))
       
   185   in  dpn m  end;
   154   in  dpn m  end;
   186 
   155 
   187 
   156 
   188 (*** Best-first search ***)
   157 (*** Best-first search ***)
   189 
   158 
   192 (
   161 (
   193   type elem = int * thm;
   162   type elem = int * thm;
   194   val ord = prod_ord int_ord (Term_Ord.term_ord o apply2 Thm.prop_of);
   163   val ord = prod_ord int_ord (Term_Ord.term_ord o apply2 Thm.prop_of);
   195 );
   164 );
   196 
   165 
   197 val trace_BEST_FIRST = Unsynchronized.ref false;
       
   198 
       
   199 (*For creating output sequence*)
   166 (*For creating output sequence*)
   200 fun some_of_list [] = NONE
   167 fun some_of_list [] = NONE
   201   | some_of_list (x :: l) = SOME (x, Seq.make (fn () => some_of_list l));
   168   | some_of_list (x :: l) = SOME (x, Seq.make (fn () => some_of_list l));
   202 
   169 
   203 (*Check for and delete duplicate proof states*)
   170 (*Check for and delete duplicate proof states*)
   208   else heap;
   175   else heap;
   209 
   176 
   210 (*Best-first search for a state that satisfies satp (incl initial state)
   177 (*Best-first search for a state that satisfies satp (incl initial state)
   211   Function sizef estimates size of problem remaining (smaller means better).
   178   Function sizef estimates size of problem remaining (smaller means better).
   212   tactic tac0 sets up the initial priority queue, while tac1 searches it. *)
   179   tactic tac0 sets up the initial priority queue, while tac1 searches it. *)
   213 fun THEN_BEST_FIRST tac0 (satp, sizef) tac1 =
   180 fun THEN_BEST_FIRST tac0 (satp, sizef) tac =
   214   let
   181   let
   215     val tac = tracify trace_BEST_FIRST tac1;
       
   216     fun pairsize th = (sizef th, th);
   182     fun pairsize th = (sizef th, th);
   217     fun bfs (news, nprf_heap) =
   183     fun bfs (news, nprf_heap) =
   218       (case List.partition satp news of
   184       (case List.partition satp news of
   219         ([], nonsats) => next (fold_rev Thm_Heap.insert (map pairsize nonsats) nprf_heap)
   185         ([], nonsats) => next (fold_rev Thm_Heap.insert (map pairsize nonsats) nprf_heap)
   220       | (sats, _)  => some_of_list sats)
   186       | (sats, _)  => some_of_list sats)
   221     and next nprf_heap =
   187     and next nprf_heap =
   222       if Thm_Heap.is_empty nprf_heap then NONE
   188       if Thm_Heap.is_empty nprf_heap then NONE
   223       else
   189       else
   224         let
   190         let
   225           val (n, prf) = Thm_Heap.min nprf_heap;
   191           val (n, prf) = Thm_Heap.min nprf_heap;
   226           val _ =
       
   227             if !trace_BEST_FIRST
       
   228             then tracing("state size = " ^ string_of_int n)
       
   229             else ();
       
   230         in
   192         in
   231           bfs (Seq.list_of (tac prf), delete_all_min prf (Thm_Heap.delete_min nprf_heap))
   193           bfs (Seq.list_of (tac prf), delete_all_min prf (Thm_Heap.delete_min nprf_heap))
   232         end;
   194         end;
   233     fun btac st = bfs (Seq.list_of (tac0 st), Thm_Heap.empty)
   195     fun btac st = bfs (Seq.list_of (tac0 st), Thm_Heap.empty)
   234   in traced_tac btac end;
   196   in fn st => Seq.make (fn () => btac st) end;
   235 
   197 
   236 (*Ordinary best-first search, with no initial tactic*)
   198 (*Ordinary best-first search, with no initial tactic*)
   237 val BEST_FIRST = THEN_BEST_FIRST all_tac;
   199 val BEST_FIRST = THEN_BEST_FIRST all_tac;
   238 
   200 
   239 (*Breadth-first search to satisfy satpred (including initial state)
   201 (*Breadth-first search to satisfy satpred (including initial state)
   269 
   231 
   270 (*For creating output sequence*)
   232 (*For creating output sequence*)
   271 fun some_of_list [] = NONE
   233 fun some_of_list [] = NONE
   272   | some_of_list (x :: xs) = SOME (x, Seq.make (fn () => some_of_list xs));
   234   | some_of_list (x :: xs) = SOME (x, Seq.make (fn () => some_of_list xs));
   273 
   235 
   274 val trace_ASTAR = Unsynchronized.ref false;
   236 fun THEN_ASTAR tac0 (satp, costf) tac =
   275 
   237   let
   276 fun THEN_ASTAR tac0 (satp, costf) tac1 =
       
   277   let
       
   278     val tf = tracify trace_ASTAR tac1;
       
   279     fun bfs (news, nprfs, level) =
   238     fun bfs (news, nprfs, level) =
   280       let fun cost thm = (level, costf level thm, thm) in
   239       let fun cost thm = (level, costf level thm, thm) in
   281         (case List.partition satp news of
   240         (case List.partition satp news of
   282           ([], nonsats) => next (fold_rev (insert_with_level o cost) nonsats nprfs)
   241           ([], nonsats) => next (fold_rev (insert_with_level o cost) nonsats nprfs)
   283         | (sats, _) => some_of_list sats)
   242         | (sats, _) => some_of_list sats)
   284       end
   243       end
   285     and next [] = NONE
   244     and next [] = NONE
   286       | next ((level, n, prf) :: nprfs) =
   245       | next ((level, n, prf) :: nprfs) = bfs (Seq.list_of (tac prf), nprfs, level + 1)
   287          (if !trace_ASTAR then
   246   in fn st => Seq.make (fn () => bfs (Seq.list_of (tac0 st), [], 0)) end;
   288           tracing ("level = " ^ string_of_int level ^
       
   289                    "  cost = " ^ string_of_int n ^
       
   290                    "  queue length =" ^ string_of_int (length nprfs))
       
   291           else ();
       
   292           bfs (Seq.list_of (tf prf), nprfs, level + 1))
       
   293     fun tf st = bfs (Seq.list_of (tac0 st), [], 0);
       
   294   in traced_tac tf end;
       
   295 
   247 
   296 (*Ordinary ASTAR, with no initial tactic*)
   248 (*Ordinary ASTAR, with no initial tactic*)
   297 val ASTAR = THEN_ASTAR all_tac;
   249 val ASTAR = THEN_ASTAR all_tac;
   298 
   250 
   299 end;
   251 end;