src/HOL/Tools/sat_solver.ML
changeset 56815 848d507584db
parent 56147 9589605bcf41
child 56845 691da43fbdd4
     1.1 --- a/src/HOL/Tools/sat_solver.ML	Thu May 01 22:41:03 2014 +0200
     1.2 +++ b/src/HOL/Tools/sat_solver.ML	Thu May 01 22:56:59 2014 +0200
     1.3 @@ -503,6 +503,201 @@
     1.4  end;
     1.5  
     1.6  (* ------------------------------------------------------------------------- *)
     1.7 +(* Internal SAT solver, available as 'SatSolver.invoke_solver "dpll_p"' --   *)
     1.8 +(* a simple, slightly more efficient implementation of the DPLL algorithm    *)
     1.9 +(* (cf. L. Zhang, S. Malik: "The Quest for Efficient Boolean Satisfiability  *)
    1.10 +(* Solvers", July 2002, Fig. 2). In contrast to the other two ML solvers     *)
    1.11 +(* above, this solver produces proof traces. *)
    1.12 +(* ------------------------------------------------------------------------- *)
    1.13 +
    1.14 +let
    1.15 +  type clause = int list * int
    1.16 +  type value = bool option
    1.17 +  datatype reason = Decided | Implied of clause | Level0 of int
    1.18 +  type variable = bool option * reason * int * int
    1.19 +  type proofs = int * int list Inttab.table
    1.20 +  type state =
    1.21 +    int * int list * variable Inttab.table * clause list Inttab.table * proofs
    1.22 +  exception CONFLICT of clause * state
    1.23 +  exception UNSAT of clause * state
    1.24 +
    1.25 +  fun neg i = ~i
    1.26 +
    1.27 +  fun lit_value lit value = if lit > 0 then value else Option.map not value
    1.28 +
    1.29 +  fun var_of vars lit: variable = the (Inttab.lookup vars (abs lit))
    1.30 +  fun value_of vars lit = lit_value lit (#1 (var_of vars lit))
    1.31 +  fun reason_of vars lit = #2 (var_of vars lit)
    1.32 +  fun level_of vars lit = #3 (var_of vars lit)
    1.33 +
    1.34 +  fun is_true vars lit = (value_of vars lit = SOME true)
    1.35 +  fun is_false vars lit = (value_of vars lit = SOME false)
    1.36 +  fun is_unassigned vars lit = (value_of vars lit = NONE)
    1.37 +  fun assignment_of vars lit = the_default NONE (try (value_of vars) lit)
    1.38 +
    1.39 +  fun put_var value reason level (_, _, _, rank) = (value, reason, level, rank)
    1.40 +  fun incr_rank (value, reason, level, rank) = (value, reason, level, rank + 1)
    1.41 +  fun update_var lit f = Inttab.map_entry (abs lit) f
    1.42 +  fun add_var lit = Inttab.update (abs lit, (NONE, Decided, ~1, 0))
    1.43 +
    1.44 +  fun assign lit r l = update_var lit (put_var (SOME (lit > 0)) r l)
    1.45 +  fun unassign lit = update_var lit (put_var NONE Decided ~1)
    1.46 +
    1.47 +  fun add_proof [] (idx, ptab) = (idx, (idx + 1, ptab))
    1.48 +    | add_proof ps (idx, ptab) = (idx, (idx + 1, Inttab.update (idx, ps) ptab))
    1.49 +
    1.50 +  fun level0_proof_of (Level0 idx) = SOME idx
    1.51 +    | level0_proof_of _ = NONE
    1.52 +
    1.53 +  fun level0_proofs_of vars = map_filter (level0_proof_of o reason_of vars)
    1.54 +  fun prems_of vars (lits, p) = p :: level0_proofs_of vars lits
    1.55 +  fun mk_proof vars cls proofs = add_proof (prems_of vars cls) proofs
    1.56 +
    1.57 +  fun push lit cls (level, trail, vars, clss, proofs) =
    1.58 +    let
    1.59 +      val (reason, proofs) =
    1.60 +        if level = 0 then apfst Level0 (mk_proof vars cls proofs)
    1.61 +        else (Implied cls, proofs)
    1.62 +    in (level, lit :: trail, assign lit reason level vars, clss, proofs) end
    1.63 +
    1.64 +  fun push_decided lit (level, trail, vars, clss, proofs) =
    1.65 +    let val vars' = assign lit Decided (level + 1) vars
    1.66 +    in (level + 1, lit :: 0 :: trail, vars', clss, proofs) end
    1.67 +
    1.68 +  fun prop (cls as (lits, _)) (cx as (units, state as (level, _, vars, _, _))) =
    1.69 +    if exists (is_true vars) lits then cx
    1.70 +    else if forall (is_false vars) lits then
    1.71 +      if level = 0 then raise UNSAT (cls, state)
    1.72 +      else raise CONFLICT (cls, state)
    1.73 +    else
    1.74 +      (case filter (is_unassigned vars) lits of
    1.75 +        [lit] => (lit :: units, push lit cls state)
    1.76 +      | _ => cx)
    1.77 +
    1.78 +  fun propagate units (state as (_, _, _, clss, _)) =
    1.79 +    (case fold (fold prop o Inttab.lookup_list clss) units ([], state) of
    1.80 +      ([], state') => (NONE, state')
    1.81 +    | (units', state') => propagate units' state')
    1.82 +    handle CONFLICT (cls, state') => (SOME cls, state')
    1.83 +
    1.84 +  fun max_unassigned (v, (NONE, _, _, rank)) (x as (_, r)) =
    1.85 +        if rank > r then (SOME v, rank) else x
    1.86 +    | max_unassigned _  x = x
    1.87 +
    1.88 +  fun decide (state as (_, _, vars, _, _)) =
    1.89 +    (case Inttab.fold max_unassigned vars (NONE, 0) of
    1.90 +      (SOME lit, _) => SOME (lit, push_decided lit state)
    1.91 +    | (NONE, _) => NONE)
    1.92 +
    1.93 +  fun mark lit = Inttab.update (abs lit, true)
    1.94 +  fun marked ms lit = the_default false (Inttab.lookup ms (abs lit))
    1.95 +  fun ignore l ms lit = ((lit = l) orelse marked ms lit)
    1.96 +
    1.97 +  fun first_lit _ [] = raise Empty
    1.98 +    | first_lit _ (0 :: _) = raise Empty
    1.99 +    | first_lit pred (lit :: lits) =
   1.100 +        if pred lit then (lit, lits) else first_lit pred lits
   1.101 +
   1.102 +  fun reason_cls_of vars lit =
   1.103 +    (case reason_of vars lit of
   1.104 +      Implied cls => cls
   1.105 +    | _ => raise Option)
   1.106 +
   1.107 +  fun analyze conflicting_cls (level, trail, vars, _, _) =
   1.108 +    let
   1.109 +      fun back i lit (lits, p) trail ms ls ps =
   1.110 +        let
   1.111 +          val (lits0, lits') = List.partition (equal 0 o level_of vars) lits
   1.112 +          val lits1 = filter_out (ignore lit ms) lits'
   1.113 +          val lits2 = filter_out (equal level o level_of vars) lits1
   1.114 +          val i' = length lits1 - length lits2 + i
   1.115 +          val ms' = fold mark lits1 ms
   1.116 +          val ls' = lits2 @ ls
   1.117 +          val ps' = level0_proofs_of vars lits0 @ (p :: ps)
   1.118 +          val (lit', trail') = first_lit (marked ms') trail
   1.119 +        in 
   1.120 +          if i' = 1 then (neg lit', ls', rev ps')
   1.121 +          else back (i' - 1) lit' (reason_cls_of vars lit') trail' ms' ls' ps'
   1.122 +        end
   1.123 +    in back 0 0 conflicting_cls trail Inttab.empty [] [] end
   1.124 +
   1.125 +  fun keep_clause (cls as (lits, _)) (level, trail, vars, clss, proofs) =
   1.126 +    let
   1.127 +      val vars' = fold (fn lit => update_var lit incr_rank) lits vars
   1.128 +      val clss' = fold (fn lit => Inttab.cons_list (neg lit, cls)) lits clss
   1.129 +    in (level, trail, vars', clss', proofs) end
   1.130 +
   1.131 +  fun learn (cls as (lits, _)) = (length lits <= 2) ? keep_clause cls
   1.132 +
   1.133 +  fun backjump _ (state as (_, [], _, _, _)) = state 
   1.134 +    | backjump i (level, 0 :: trail, vars, clss, proofs) =
   1.135 +        (level - 1, trail, vars, clss, proofs) |> (i > 1) ? backjump (i - 1)
   1.136 +    | backjump i (level, lit :: trail, vars, clss, proofs) =
   1.137 +        backjump i (level, trail, unassign lit vars, clss, proofs)
   1.138 +
   1.139 +  fun search units state =
   1.140 +    (case propagate units state of
   1.141 +      (NONE, state' as (_, _, vars, _, _)) =>
   1.142 +        (case decide state' of
   1.143 +          NONE => SatSolver.SATISFIABLE (assignment_of vars)
   1.144 +        | SOME (lit, state'') => search [lit] state'')
   1.145 +    | (SOME conflicting_cls, state' as (level, trail, vars, clss, proofs)) =>
   1.146 +        let 
   1.147 +          val (lit, lits, ps) = analyze conflicting_cls state'
   1.148 +          val (idx, proofs') = add_proof ps proofs
   1.149 +          val cls = (lit :: lits, idx)
   1.150 +        in
   1.151 +          (level, trail, vars, clss, proofs')
   1.152 +          |> backjump (level - fold (Integer.max o level_of vars) lits 0)
   1.153 +          |> learn cls
   1.154 +          |> push lit cls
   1.155 +          |> search [lit]
   1.156 +        end)
   1.157 +
   1.158 +  fun has_opposing_lits [] = false
   1.159 +    | has_opposing_lits (lit :: lits) =
   1.160 +        member (op =) lits (neg lit) orelse has_opposing_lits lits
   1.161 +
   1.162 +  fun add_clause (cls as ([_], _)) (units, state) =
   1.163 +        let val (units', state') = prop cls (units, state)
   1.164 +        in (units', state') end
   1.165 +    | add_clause (cls as (lits, _)) (cx as (units, state)) =
   1.166 +        if has_opposing_lits lits then cx
   1.167 +        else (units, keep_clause cls state)
   1.168 +
   1.169 +  fun mk_clause lits proofs =
   1.170 +    apfst (pair (distinct (op =) lits)) (add_proof [] proofs)
   1.171 +
   1.172 +  fun solve litss =
   1.173 +    let
   1.174 +      val (clss, proofs) = fold_map mk_clause litss (0, Inttab.empty)
   1.175 +      val vars = fold (fold add_var) litss Inttab.empty
   1.176 +      val state = (0, [], vars, Inttab.empty, proofs)
   1.177 +    in uncurry search (fold add_clause clss ([], state)) end
   1.178 +    handle UNSAT (conflicting_cls, (_, _, vars, _, proofs)) =>
   1.179 +      let val (idx, (_, ptab)) = mk_proof vars conflicting_cls proofs
   1.180 +      in SatSolver.UNSATISFIABLE (SOME (ptab, idx)) end
   1.181 +
   1.182 +  fun variable_of (Prop_Logic.BoolVar 0) = error "bad propositional variable"
   1.183 +    | variable_of (Prop_Logic.BoolVar i) = i
   1.184 +    | variable_of _ = error "expected formula in CNF"
   1.185 +  fun literal_of (Prop_Logic.Not fm) = neg (variable_of fm)
   1.186 +    | literal_of fm = variable_of fm
   1.187 +  fun clause_of (Prop_Logic.Or (fm1, fm2)) = clause_of fm1 @ clause_of fm2
   1.188 +    | clause_of fm = [literal_of fm]
   1.189 +  fun clauses_of (Prop_Logic.And (fm1, fm2)) = clauses_of fm1 @ clauses_of fm2
   1.190 +    | clauses_of Prop_Logic.True = [[1, ~1]]
   1.191 +    | clauses_of Prop_Logic.False = [[1], [~1]]
   1.192 +    | clauses_of fm = [clause_of fm]
   1.193 +
   1.194 +  fun dpll_solver fm =
   1.195 +    let val fm' = if Prop_Logic.is_cnf fm then fm else Prop_Logic.defcnf fm
   1.196 +    in solve (clauses_of fm') end
   1.197 +in
   1.198 +  SatSolver.add_solver ("dpll_p", dpll_solver)
   1.199 +end;
   1.200 +
   1.201 +(* ------------------------------------------------------------------------- *)
   1.202  (* Internal SAT solver, available as 'SatSolver.invoke_solver "auto"': uses  *)
   1.203  (* the last installed solver (other than "auto" itself) that does not raise  *)
   1.204  (* 'NOT_CONFIGURED'.  (However, the solver may return 'UNKNOWN'.)            *)