--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/src/Tools/Argo/argo_cdcl.ML Thu Sep 29 20:54:44 2016 +0200
@@ -0,0 +1,477 @@
+(* Title: Tools/Argo/argo_cdcl.ML
+ Author: Sascha Boehme
+
+Propositional satisfiability solver in the style of conflict-driven
+clause-learning (CDCL). It features:
+
+ * conflict analysis and clause learning based on the first unique implication point
+ * nonchronological backtracking
+ * dynamic variable ordering (VSIDS)
+ * restarting
+ * polarity caching
+ * propagation via two watched literals
+ * special propagation of binary clauses
+ * minimizing learned clauses
+ * support for external knowledge
+
+These features might be added:
+
+ * pruning of unnecessary learned clauses
+ * rebuilding the variable heap
+ * aligning the restart level with the decision heuristics: keep decisions that would
+ be recovered instead of backjumping to level 0
+
+The implementation is inspired by:
+
+ Niklas E'en and Niklas S"orensson. An Extensible SAT-solver. In Enrico
+ Giunchiglia and Armando Tacchella, editors, Theory and Applications of
+ Satisfiability Testing. Volume 2919 of Lecture Notes in Computer
+ Science, pages 502-518. Springer, 2003.
+
+ Niklas S"orensson and Armin Biere. Minimizing Learned Clauses. In
+ Oliver Kullmann, editor, Theory and Applications of Satisfiability
+ Testing. Volume 5584 of Lecture Notes in Computer Science,
+ pages 237-243. Springer, 2009.
+*)
+
+signature ARGO_CDCL =
+sig
+ (* types *)
+ type 'a explain = Argo_Lit.literal -> 'a -> Argo_Cls.clause * 'a (* threads a state 'a *)
+
+ (* context *)
+ type context
+ val context: context (* the initial context: level 0, nothing assigned *)
+ val assignment_of: context -> Argo_Lit.literal -> bool option (* NONE if unassigned *)
+
+ (* enriching the context *)
+ val add_atom: Argo_Term.term -> context -> context
+ val add_axiom: Argo_Cls.clause -> context -> int * context (* int: levels backjumped *)
+
+ (* main operations *)
+ val assume: 'a explain -> Argo_Lit.literal -> context -> 'a ->
+ Argo_Cls.clause option * context * 'a (* SOME clause: conflict above level 0 *)
+ val propagate: context -> Argo_Common.literal Argo_Common.implied * context
+ val decide: context -> context option (* NONE if the heap yields no unassigned literal *)
+ val analyze: 'a explain -> Argo_Cls.clause -> context -> 'a -> int * context * 'a
+ val restart: context -> int * context (* backjumps to level 0 *)
+end
+
+structure Argo_Cdcl: ARGO_CDCL =
+struct
+
+(* basic types and operations *)
+
+type 'a explain = Argo_Lit.literal -> 'a -> Argo_Cls.clause * 'a (* clause for a literal *)
+
+datatype reason = (* why a literal has been assigned *)
+ Level0 of Argo_Proof.proof | (* fact on level 0 with its proof *)
+ Decided of int * int * (bool * reason) Argo_Termtab.table | (* level, height, values before *)
+ Implied of int * int * (Argo_Lit.literal * reason) list * Argo_Proof.proof | (* level, height, justification, proof *)
+ External of int (* assumed from outside on the given level *)
+
+(* the decision level recorded in a reason; level-0 facts carry no explicit level *)
+fun level_of reason =
+  (case reason of Level0 _ => 0 | External lvl => lvl
+  | Decided (lvl, _, _) => lvl | Implied (lvl, _, _, _) => lvl)
+
+type justified = Argo_Lit.literal * reason (* a literal together with the reason of its value *)
+
+type watches = Argo_Cls.clause list * Argo_Cls.clause list (* clauses watching Neg t, Pos t *)
+
+fun get_watches table = Argo_Termtab.lookup table (* the watch lists of a term, if any *)
+fun map_watches f term = Argo_Termtab.map_default (term, ([], [])) f (* default: no watches *)
+
+fun map_lit_watches f lit = (* Neg literals own the first watch list, Pos literals the second *)
+  (case lit of Argo_Lit.Neg t => map_watches (apfst f) t | Argo_Lit.Pos t => map_watches (apsnd f) t)
+
+fun watches_of wts (Argo_Lit.Neg t) = these (Option.map snd (get_watches wts t)) (* second list *)
+  | watches_of wts (Argo_Lit.Pos t) = these (Option.map fst (get_watches wts t)) (* first list *)
+
+fun attach cls = map_lit_watches (cons cls) (* cls starts watching a literal *)
+fun detach cls = map_lit_watches (remove Argo_Cls.eq_clause cls) (* cls stops watching it *)
+
+
+(* literal values *)
+
+fun raw_val_of vals = Argo_Termtab.lookup vals o Argo_Lit.term_of (* entry of the literal's term *)
+
+fun val_of vals (Argo_Lit.Neg t) = Option.map (fn (b, r) => (not b, r)) (Argo_Termtab.lookup vals t)
+  | val_of vals (Argo_Lit.Pos t) = Argo_Termtab.lookup vals t (* Neg flips the stored polarity *)
+
+fun value_of vals lit = (* truth value of a literal, NONE if its term is unassigned *)
+  Option.map fst (val_of vals lit)
+
+fun justified vals lit = Option.map (fn (_, r) => (lit, r)) (raw_val_of vals lit) (* lit and reason *)
+fun the_reason_of vals lit = let val (_, r) = the (raw_val_of vals lit) in r end (* lit is assigned *)
+
+fun assign lit r = (* stores polarity and reason under the term of the literal *)
+  Argo_Termtab.update (case lit of Argo_Lit.Pos t => (t, (true, r)) | Argo_Lit.Neg t => (t, (false, r)))
+
+
+(* context *)
+
+type trail = int * justified list (* the trail height and the assigned literals, latest first *)
+
+type context = {
+ units: Argo_Common.literal list, (* the literals that await propagation, latest first *)
+ level: int, (* the current decision level *)
+ trail: int * justified list, (* the trail height and the assigned literals, latest first *)
+ vals: (bool * reason) Argo_Termtab.table, (* mapping of terms to polarity and reason *)
+ wts: watches Argo_Termtab.table, (* per term: clauses watching its Neg and its Pos literal *)
+ heap: Argo_Heap.heap, (* max-priority heap for decision heuristics *)
+ clss: Argo_Cls.table, (* information about clauses, e.g. their watched literals *)
+ prf: Argo_Proof.context} (* the proof context *)
+
+fun mk_context units level trail vals wts heap clss prf: context = (* positional record builder *)
+  {prf=prf, clss=clss, heap=heap, wts=wts, vals=vals, trail=trail, level=level, units=units}
+
+val context: context = (* the initial context: level 0, empty trail, nothing assigned or watched *)
+  {units=[], level=0, trail=(0, []), vals=Argo_Termtab.empty, wts=Argo_Termtab.empty,
+   heap=Argo_Heap.heap, clss=Argo_Cls.table, prf=Argo_Proof.cdcl_context}
+
+(* pops the trail into the heap until the decision of level n + 1 has been undone *)
+fun drop_levels n (Decided (lvl, h, vals)) rest heap =
+      if lvl <> n + 1 then drop_literal n rest heap else ((h, rest), vals, heap)
+  | drop_levels n _ rest heap = drop_literal n rest heap
+and drop_literal _ [] _ = raise Fail "bad trail"
+  | drop_literal n ((lit, r) :: rest) heap = drop_levels n r rest (Argo_Heap.insert lit heap)
+
+fun backjump_to new_level (cx: context) = (* also yields the number of levels given up *)
+  if #level cx <= new_level then (0, cx)
+  else
+    let val (trail, vals, heap) = drop_literal (Integer.max 0 new_level) (snd (#trail cx)) (#heap cx)
+    in (#level cx - new_level, mk_context [] new_level trail vals (#wts cx) heap (#clss cx) (#prf cx)) end
+
+
+(* proofs *)
+
+fun tag_clause (lits, p) prf = (* pairs the literals with the proof made by mk_clause *)
+  let val (p', prf') = Argo_Proof.mk_clause lits p prf in ((lits, p'), prf') end
+fun level0_unit_proof (lit, reason) (p, prf) = (* resolves with the level-0 proof of lit *)
+  (case reason of Level0 p' => Argo_Proof.mk_unit_res lit p p' prf | _ => raise Fail "bad reason")
+
+fun level0_unit_proofs lrs = curry (fold level0_unit_proof lrs) (* one resolution step per lrs entry *)
+
+fun unsat (cx: context) (lits, p) = (* every literal of the clause needs a level-0 reason *)
+  let val (p', _) = level0_unit_proofs (map (fn l => (l, the_reason_of (#vals cx) l)) lits) p (#prf cx)
+  in Argo_Proof.unsat p' end
+
+
+(* literal operations *)
+
+fun push lit p reason prf ({units, level, trail=(h, tr), vals, wts, heap, clss, ...}: context) =
+  {units = (lit, p) :: units, level = level, trail = (h + 1, (lit, reason) :: tr), prf = prf,
+   vals = assign lit reason vals, wts = wts, heap = heap, clss = clss}: context (* lit assigned *)
+
+fun push_level0 lit p lrs (cx: context) = (* the proof is first resolved with the lrs *)
+  level0_unit_proofs lrs p (#prf cx)
+  |> (fn (p', prf') => push lit (SOME p') (Level0 p') prf' cx)
+
+fun push_implied lit p lrs (cx: context) = (* implications on level 0 become level-0 facts *)
+  if #level cx <= 0 then push_level0 lit p lrs cx
+  else push lit NONE (Implied (#level cx, fst (#trail cx), lrs, p)) (#prf cx) cx
+
+fun push_decided lit (cx: context) = (* remembers trail height and values for backjumping *)
+  push lit NONE (Decided (#level cx, fst (#trail cx), #vals cx)) (#prf cx) cx
+
+fun assignment_of (cx: context) lit = value_of (#vals cx) lit (* NONE if lit is unassigned *)
+
+fun replace_watches old new cls ({units, level, trail, vals, wts, heap, clss, prf}: context) =
+  mk_context units level trail vals (wts |> detach cls old |> attach cls new) heap clss prf (* old to new *)
+
+
+(* clause operations *)
+
+fun as_clause cls ({units, level, trail, vals, wts, heap, clss, prf}: context) =
+  (* only the proof context changes: tag_clause yields the clause and the new proof context *)
+  tag_clause cls prf ||> mk_context units level trail vals wts heap clss
+
+fun note_watches (cls as (lits, _)) lp clss = (* binary clauses get no recorded watch pair *)
+  (case lits of [_, _] => clss | _ => Argo_Cls.put_watches cls lp clss)
+
+fun attach_clause lit1 lit2 (cls as (lits, _)) (cx: context) =
+  let (* cls starts watching both literals; all its literals are passed to Argo_Heap.count *)
+    val {units, level, trail, vals, wts, heap, clss, prf} = cx
+    val (wts', heap') = (wts |> attach cls lit2 |> attach cls lit1, fold Argo_Heap.count lits heap)
+    val clss' = note_watches cls (lit1, lit2) clss
+  in mk_context units level trail vals wts' heap' clss' prf end
+
+fun change_watches cls (flag, l1, l2) (cx as {units, level, trail, vals, wts, heap, clss, prf}: context) =
+  if flag then mk_context units level trail vals wts heap (Argo_Cls.put_watches cls (l1, l2) clss) prf
+  else cx (* the watch pair is re-recorded only when the flag is set *)
+
+fun add_asserting lit lit' (cls as (_, p)) lrs cx = (* lit is implied first, then cls attached *)
+  cx |> push_implied lit p lrs |> attach_clause lit lit' cls
+
+(*
+ When learning a non-unit clause, the context is backtracked to the highest decision level
+ among the justifying literals, and the first literal of the clause is then asserted.
+*)
+
+fun learn_clause lrs (cls as (lits, p)) cx = (* unit lemmas become level-0 facts *)
+  (case lits of
+    [unit] => backjump_to 0 cx ||> push_level0 unit p []
+  | _ =>
+      let fun higher (l, r) (best as (_, lvl)) = if lvl < level_of r then (l, level_of r) else best
+      in fold higher lrs (hd lits, 0) |> (fn (lit, lvl) => backjump_to lvl cx ||> add_asserting (hd lits) lit cls lrs) end)
+
+(*
+ An axiom with one unassigned literal and all remaining literals being assigned to
+ false is asserting. An axiom with all literals assigned to false on level 0 makes the
+ context unsatisfiable. An axiom with all literals assigned to false on higher levels
+ causes backjumping before the highest level, and then the axiom might be asserting if
+ only one literal is unassigned on that level.
+*)
+
+fun min lit i best = (* keeps the entry with the smallest number; the earlier entry wins ties *)
+  (case best of NONE => SOME (lit, i) | SOME (_, j) => if i < j then SOME (lit, i) else best)
+
+fun level_ord ((_, r1), (_, r2)) = rev_order (int_ord (level_of r1, level_of r2)) (* descending *)
+fun add_max lr = Ord_List.insert level_ord lr (* inserted with respect to descending levels *)
+
+fun part (v :: vs) (l :: ls) t us fs = (* true: lowest level wins; NONE: unassigned; false: by level *)
+      (case v of
+        NONE => part vs ls t (l :: us) fs | SOME (false, r) => part vs ls t us (add_max (l, r) fs)
+      | SOME (true, r) => part vs ls (min l (level_of r) t) us fs)
+  | part vs ls t us fs = if null vs andalso null ls then (t, us, fs) else raise Fail "mismatch between values and literals"
+
+(* backjumps below the level of (lit, r); cls then asserts lit unless lit' shares that level *)
+fun backjump_add (lit, r) (lit', r') cls lrs cx =
+  let
+    val same_level = (level_of r = level_of r')
+    val add = if same_level then attach_clause lit lit' cls else add_asserting lit lit' cls lrs
+  in backjump_to (level_of r - 1) cx ||> add end
+
+fun analyze_axiom vs (cls as (lits, p), cx) =
+ (case part vs lits NONE [] [] of (* (lowest true literal, unassigned, false by level) *)
+ (SOME (lit, lvl), [], []) => (* a unit clause whose literal is already true *)
+ if lvl > 0 then backjump_to 0 cx ||> push_implied lit p [] else (0, cx)
+ | (SOME (lit, lvl), [], (lit', _) :: _) => (0, cx |> (lvl > 0) ? attach_clause lit lit' cls)
+ | (SOME (lit, lvl), lit' :: _, _) => (0, cx |> (lvl > 0) ? attach_clause lit lit' cls)
+ | (NONE, [], (_, Level0 _) :: _) => unsat cx cls (* all literals false on level 0 *)
+ | (NONE, [], [(lit, _)]) => backjump_to 0 cx ||> push_implied lit p [] (* false unit clause *)
+ | (NONE, [], lrs as (lr :: lr' :: _)) => backjump_add lr lr' cls lrs cx (* all literals false *)
+ | (NONE, [lit], []) => backjump_to 0 cx ||> push_implied lit p [] (* unassigned unit clause *)
+ | (NONE, [lit], lrs as (lit', _) :: _) => (0, add_asserting lit lit' cls lrs cx) (* asserting *)
+ | (NONE, lit1 :: lit2 :: _, _) => (0, attach_clause lit1 lit2 cls cx) (* two unassigned *)
+ | _ => raise Fail "bad clause")
+
+
+(* enriching the context *)
+
+fun add_atom t (cx: context) = (* the atom enters the decision heap as a positive literal *)
+  let val {units, level, trail, vals, wts, heap, clss, prf} = cx
+  in mk_context units level trail vals wts (Argo_Heap.insert (Argo_Lit.Pos t) heap) clss prf end
+
+fun add_axiom (cls as (lits, p)) (cx: context) = (* clauses with dual literals are skipped *)
+  if null lits then Argo_Proof.unsat p
+  else if has_duplicates Argo_Lit.eq_lit lits then raise Fail "clause with duplicate literals"
+  else if not (has_duplicates Argo_Lit.dual_lit lits) then analyze_axiom (map (val_of (#vals cx)) lits) (as_clause cls cx)
+  else (0, cx)
+
+
+(* external knowledge *)
+
+fun assume explain lit (cx as {level, vals, prf, ...}: context) x =
+  let
+    (* a literal that is already false yields its explanation as conflict clause *)
+    fun refuted (cls, x') = if level = 0 then unsat cx cls else (SOME cls, cx, x')
+    (* on level 0 a fresh literal is justified at once by its explanation *)
+    fun fact ((lits, p), x') = (NONE, push_level0 lit p (map_filter (justified vals) lits) cx, x')
+  in
+    (case value_of vals lit of
+      SOME holds => if holds then (NONE, cx, x) else refuted (explain lit x)
+    | NONE => if level = 0 then fact (explain lit x) else (NONE, push lit NONE (External level) prf cx, x))
+  end
+
+
+(* propagation *)
+
+exception CONFLICT of Argo_Cls.clause * context (* a clause found false above level 0 *)
+
+fun order_lits_by lit (l1, l2) = (* moves the watch with the id of lit last; flags a swap *)
+  (case Argo_Lit.eq_id (l1, lit) of true => (true, l2, l1) | false => (false, l1, l2))
+
+fun prop_binary (_, implied_lit, other_lit) (cls as (_, p)) (cx as {level, vals, ...}: context) =
+  (case value_of vals implied_lit of (* other_lit serves as the sole justification *)
+    SOME true => cx
+  | SOME false => if level <> 0 then raise CONFLICT (cls, cx) else unsat cx cls
+  | NONE => push_implied implied_lit p [(other_lit, the_reason_of vals other_lit)] cx)
+
+datatype next = Lit of Argo_Lit.literal | None of justified list (* non-false literal, or all false ones *)
+
+fun with_non_false f l v lrs = (* stops at a literal that is true or unassigned *)
+  (case v of SOME (false, r) => f ((l, r) :: lrs) | _ => Lit l)
+
+fun first_non_false vals lit (l :: ls) lrs = (* looks for a literal other than lit that is not false *)
+      if not (Argo_Lit.eq_lit (l, lit)) then with_non_false (first_non_false vals lit ls) l (val_of vals l) lrs
+      else first_non_false vals lit ls lrs
+  | first_non_false _ _ [] lrs = None lrs
+
+fun prop_nary (lp as (_, lit1, lit2)) (cls as (lits, p)) (cx as {level, vals, ...}: context) =
+ let val v = value_of vals lit1 (* value of the watched literal that is kept *)
+ in
+ if v = SOME true then change_watches cls lp cx (* cls is satisfied by lit1 *)
+ else
+ (case first_non_false vals lit1 lits [] of
+ Lit lit2' => change_watches cls (true, lit1, lit2') (replace_watches lit2 lit2' cls cx)
+ | None lrs => (* every literal other than lit1 is false *)
+ if v = NONE then push_implied lit1 p lrs (change_watches cls lp cx)
+ else if level = 0 then unsat cx cls
+ else raise CONFLICT (cls, change_watches cls lp cx))
+ end
+
+fun prop_cls lit (cls as (lits, _)) (cx: context) = (* binary clauses watch both their literals *)
+  (case lits of [l1, l2] => prop_binary (order_lits_by lit (l1, l2)) cls cx
+  | _ => prop_nary (order_lits_by lit (Argo_Cls.get_watches (#clss cx) cls)) cls cx)
+
+fun prop_lit (lp as (lit, _)) (lps, cx: context) = (* visits the clauses given by watches_of *)
+  (lp :: lps, cx |> fold (prop_cls lit) (watches_of (#wts cx) lit))
+
+fun prop lps (cx as {units, level, trail, vals, wts, heap, clss, prf}: context) =
+  if null units then (Argo_Common.Implied (rev lps), cx) (* nothing left to propagate *)
+  else (lps, mk_context [] level trail vals wts heap clss prf) |> fold_rev prop_lit units |-> prop
+
+fun propagate cx = (* a conflict ends propagation with the conflicting clause *)
+  (prop [] cx) handle CONFLICT (cls, cx') => (Argo_Common.Conflict cls, cx')
+
+
+(* decisions *)
+
+(*
+ Decisions are based on an activity heuristic. The most active variable that is
+ still unassigned is chosen.
+*)
+
+fun decide ({units, level, trail, vals, wts, heap, clss, prf}: context) =
+  let
+    fun assigned lit = Argo_Termtab.defined vals (Argo_Lit.term_of lit)
+    fun next h = (* skips extracted literals whose term already has a value *)
+      (case Argo_Heap.extract h of NONE => NONE | SOME (lit, h') => if assigned lit then next h'
+        else SOME (push_decided lit (mk_context units (level + 1) trail vals wts h' clss prf)))
+  in next heap end
+
+
+(* conflict analysis and clause learning *)
+
+(*
+ Learned clauses often contain literals that are redundant, because they are
+ subsumed by other literals of the clause. By analyzing the implication graph beyond
+ the unique implication point, such redundant literals can be identified and hence
+ removed from the learned clause. Only literals occurring in the learned clause and
+ their reasons need to be analyzed.
+*)
+
+exception ESSENTIAL of unit (* the literal under analysis must stay in the learned clause *)
+
+fun history_ord ((h1, lit1, _), (h2, lit2, _)) = (* greater heights first; negative ones by id *)
+  if h1 >= 0 orelse h2 >= 0 then int_ord (h2, h1)
+  else int_ord (Argo_Lit.signed_id_of lit1, Argo_Lit.signed_id_of lit2)
+
+(* collects the resolution steps that remove a literal, or raises ESSENTIAL if it must stay *)
+fun rec_redundant stop (lit, reason) lps =
+  (case reason of
+    Level0 p => (~1, lit, p) :: lps
+  | Implied (lvl, h, lrs, p) =>
+      if stop lit lvl then lps else fold (rec_redundant stop) lrs ((h, lit, p) :: lps)
+  | Decided (lvl, _, _) => if stop lit lvl then lps else raise ESSENTIAL ()
+  | External _ => raise ESSENTIAL ())
+
+fun redundant stop (lr as (lit, reason)) (lps, essential_lrs) =
+  (case reason of (* only implied literals can be redundant; all others stay essential *)
+    Implied (_, h, lrs, p) => (fold (rec_redundant stop) lrs ((h, lit, p) :: lps), essential_lrs)
+  | _ => raise ESSENTIAL ()) handle ESSENTIAL () => (lps, lr :: essential_lrs)
+
+fun resolve_step (_, lit, q) (p, prf) = prf |> Argo_Proof.mk_unit_res lit p q (* one unit resolution *)
+
+fun reduce lrs p prf =
+ let
+ val lits = map fst lrs
+ val levels = fold (insert (op =) o level_of o snd) lrs [] (* the levels occurring in lrs *)
+ fun stop lit level =
+ if member Argo_Lit.eq_lit lits lit then true (* lit is itself among the lrs *)
+ else if member (op =) levels level then false
+ else raise ESSENTIAL () (* the level of lit does not occur in lrs *)
+
+ val (lps, lrs) = fold (redundant stop) lrs ([], []) (* lrs: the literals that must stay *)
+ in (lrs, fold resolve_step (sort_distinct history_ord lps) (p, prf)) end
+
+(*
+ Literals that are candidates for the learned lemma are marked and unmarked while
+ traversing backwards through the trail. The last remaining marked literal is the first
+ unique implication point.
+*)
+
+fun unmark lit = remove Argo_Lit.eq_id lit (* marks are compared via Argo_Lit.eq_id *)
+fun marked ms = member Argo_Lit.eq_id ms (* is a literal among the marked ones? *)
+
+(*
+ Whenever an implication is recorded, the reasons for the false literals of the
+ asserting clause are known. It is reasonable to store this justification list as part
+ of the implication reason. Consequently, the implementation of conflict analysis can
+ benefit from this information, which does not need to be re-computed here.
+*)
+
+fun justification_for explain vals lit reason x = (* external reasons are explained on demand *)
+  (case reason of
+    Implied (_, _, lrs, p) => (lrs, p, x)
+  | External _ => explain lit x |> (fn ((lits, p), x') => (map_filter (justified vals) lits, p, x'))
+  | _ => raise Fail "bad reason")
+
+fun first_lit _ [] = raise Empty (* no entry satisfies the predicate *)
+  | first_lit pred ((lr as (lit, _)) :: rest) = if not (pred lit) then first_lit pred rest else (lr, rest)
+
+(*
+ Beginning from the conflicting clause, the implication graph is traversed to the first
+ unique implication point. This breadth-first search is controlled by the topological order of
+ the trail, which is traversed backwards. While traversing through the trail, the conflict
+ literals of lower levels are collected to form the conflict lemma together with the unique
+ implication point. Conflict literals assigned on level 0 are excluded from the conflict lemma.
+ Conflict literals assigned on the current level are candidates for the first unique
+ implication point.
+*)
+
+fun analyze explain cls (cx as {level, trail, vals, wts, heap, clss, prf, ...}: context) x =
+ let
+ fun from_clause [] trail ms lrs h p prf x = (* clause done: go on with the next marked literal *)
+ from_trail (first_lit (marked ms) trail) ms lrs h p prf x
+ | from_clause ((lit, r) :: clause_lrs) trail ms lrs h p prf x =
+ from_reason r lit clause_lrs trail ms lrs h p prf x
+
+ and from_reason (Level0 p') lit clause_lrs trail ms lrs h p prf x = (* resolved at once *)
+ let val (p, prf) = Argo_Proof.mk_unit_res lit p p' prf
+ in from_clause clause_lrs trail ms lrs h p prf x end
+ | from_reason r lit clause_lrs trail ms lrs h p prf x =
+ if level_of r = level then (* current level: mark lit as candidate *)
+ if marked ms lit then from_clause clause_lrs trail ms lrs h p prf x
+ else from_clause clause_lrs trail (lit :: ms) lrs (Argo_Heap.increase lit h) p prf x
+ else (* other level: collect lit for the lemma *)
+ let
+ val (lrs, h) =
+ if AList.defined Argo_Lit.eq_id lrs lit then (lrs, h)
+ else ((lit, r) :: lrs, Argo_Heap.increase lit h)
+ in from_clause clause_lrs trail ms lrs h p prf x end
+
+ and from_trail ((lit, _), _) [_] lrs h p prf x = (* a single mark is left: lit is the UIP *)
+ let val (lrs, (p, prf)) = reduce lrs p prf
+ in (Argo_Lit.negate lit :: map fst lrs, lrs, h, p, prf, x) end
+ | from_trail ((lit, r), trail) ms lrs h p prf x = (* resolve with the justification of lit *)
+ let
+ val (clause_lrs, p', x) = justification_for explain vals lit r x
+ val (p, prf) = Argo_Proof.mk_unit_res lit p' p prf
+ in from_clause clause_lrs trail (unmark lit ms) lrs h p prf x end
+
+ val (ls, p) = cls
+ val lrs = if level = 0 then unsat cx cls else map (fn l => (l, the_reason_of vals l)) ls
+ val (lits, lrs, heap, p, prf, x) = from_clause lrs (snd trail) [] [] heap p prf x
+ val heap = Argo_Heap.decay heap
+ val (levels, cx) = learn_clause lrs (lits, p) (mk_context [] level trail vals wts heap clss prf)
+ in (levels, cx, x) end
+
+
+(* restarting *)
+
+fun restart cx = cx |> backjump_to 0 (* gives up all decisions; yields the levels dropped *)
+
+end