src/HOL/Tools/sat_solver.ML
changeset 56815 848d507584db
parent 56147 9589605bcf41
child 56845 691da43fbdd4
equal deleted inserted replaced
56814:eb8f2a5a57ad 56815:848d507584db
   498       | NONE            => SatSolver.UNSATISFIABLE NONE
   498       | NONE            => SatSolver.UNSATISFIABLE NONE
   499     end
   499     end
   500   end  (* local *)
   500   end  (* local *)
   501 in
   501 in
   502   SatSolver.add_solver ("dpll", dpll_solver)
   502   SatSolver.add_solver ("dpll", dpll_solver)
       
   503 end;
       
   504 
       
   505 (* ------------------------------------------------------------------------- *)
       
   506 (* Internal SAT solver, available as 'SatSolver.invoke_solver "dpll_p"' --   *)
       
   507 (* a simple, slightly more efficient implementation of the DPLL algorithm    *)
       
   508 (* (cf. L. Zhang, S. Malik: "The Quest for Efficient Boolean Satisfiability  *)
       
   509 (* Solvers", July 2002, Fig. 2). In contrast to the other two ML solvers     *)
       
    510 (* above, this solver produces proof traces.                                 *)
       
   511 (* ------------------------------------------------------------------------- *)
       
   512 
       
   513 let
       
   514   type clause = int list * int
       
   515   type value = bool option
       
   516   datatype reason = Decided | Implied of clause | Level0 of int
       
   517   type variable = bool option * reason * int * int
       
   518   type proofs = int * int list Inttab.table
       
   519   type state =
       
   520     int * int list * variable Inttab.table * clause list Inttab.table * proofs
       
   521   exception CONFLICT of clause * state
       
   522   exception UNSAT of clause * state
       
   523 
       
   524   fun neg i = ~i
       
   525 
       
   526   fun lit_value lit value = if lit > 0 then value else Option.map not value
       
   527 
       
   528   fun var_of vars lit: variable = the (Inttab.lookup vars (abs lit))
       
   529   fun value_of vars lit = lit_value lit (#1 (var_of vars lit))
       
   530   fun reason_of vars lit = #2 (var_of vars lit)
       
   531   fun level_of vars lit = #3 (var_of vars lit)
       
   532 
       
   533   fun is_true vars lit = (value_of vars lit = SOME true)
       
   534   fun is_false vars lit = (value_of vars lit = SOME false)
       
   535   fun is_unassigned vars lit = (value_of vars lit = NONE)
       
   536   fun assignment_of vars lit = the_default NONE (try (value_of vars) lit)
       
   537 
       
   538   fun put_var value reason level (_, _, _, rank) = (value, reason, level, rank)
       
   539   fun incr_rank (value, reason, level, rank) = (value, reason, level, rank + 1)
       
   540   fun update_var lit f = Inttab.map_entry (abs lit) f
       
   541   fun add_var lit = Inttab.update (abs lit, (NONE, Decided, ~1, 0))
       
   542 
       
   543   fun assign lit r l = update_var lit (put_var (SOME (lit > 0)) r l)
       
   544   fun unassign lit = update_var lit (put_var NONE Decided ~1)
       
   545 
       
   546   fun add_proof [] (idx, ptab) = (idx, (idx + 1, ptab))
       
   547     | add_proof ps (idx, ptab) = (idx, (idx + 1, Inttab.update (idx, ps) ptab))
       
   548 
       
   549   fun level0_proof_of (Level0 idx) = SOME idx
       
   550     | level0_proof_of _ = NONE
       
   551 
       
   552   fun level0_proofs_of vars = map_filter (level0_proof_of o reason_of vars)
       
   553   fun prems_of vars (lits, p) = p :: level0_proofs_of vars lits
       
   554   fun mk_proof vars cls proofs = add_proof (prems_of vars cls) proofs
       
   555 
       
   556   fun push lit cls (level, trail, vars, clss, proofs) =
       
   557     let
       
   558       val (reason, proofs) =
       
   559         if level = 0 then apfst Level0 (mk_proof vars cls proofs)
       
   560         else (Implied cls, proofs)
       
   561     in (level, lit :: trail, assign lit reason level vars, clss, proofs) end
       
   562 
       
   563   fun push_decided lit (level, trail, vars, clss, proofs) =
       
   564     let val vars' = assign lit Decided (level + 1) vars
       
   565     in (level + 1, lit :: 0 :: trail, vars', clss, proofs) end
       
   566 
       
   567   fun prop (cls as (lits, _)) (cx as (units, state as (level, _, vars, _, _))) =
       
   568     if exists (is_true vars) lits then cx
       
   569     else if forall (is_false vars) lits then
       
   570       if level = 0 then raise UNSAT (cls, state)
       
   571       else raise CONFLICT (cls, state)
       
   572     else
       
   573       (case filter (is_unassigned vars) lits of
       
   574         [lit] => (lit :: units, push lit cls state)
       
   575       | _ => cx)
       
   576 
       
    (* Unit propagation to a fixpoint.  For each literal in units, the clauses
       stored under that literal in clss are re-examined with prop; literals
       forced in one round form the units of the next round.  Returns
       (NONE, state') once a round forces nothing, or (SOME cls, state') with
       the clause for which prop raised CONFLICT (state' is the state at that
       point).  UNSAT is not handled here and escapes to the caller. *)
    fun propagate units (state as (_, _, _, clss, _)) =
      (case fold (fold prop o Inttab.lookup_list clss) units ([], state) of
        ([], state') => (NONE, state')
      | (units', state') => propagate units' state')
      handle CONFLICT (cls, state') => (SOME cls, state')
       
   582 
       
   583   fun max_unassigned (v, (NONE, _, _, rank)) (x as (_, r)) =
       
   584         if rank > r then (SOME v, rank) else x
       
   585     | max_unassigned _  x = x
       
   586 
       
   587   fun decide (state as (_, _, vars, _, _)) =
       
   588     (case Inttab.fold max_unassigned vars (NONE, 0) of
       
   589       (SOME lit, _) => SOME (lit, push_decided lit state)
       
   590     | (NONE, _) => NONE)
       
   591 
       
   592   fun mark lit = Inttab.update (abs lit, true)
       
   593   fun marked ms lit = the_default false (Inttab.lookup ms (abs lit))
       
   594   fun ignore l ms lit = ((lit = l) orelse marked ms lit)
       
   595 
       
   596   fun first_lit _ [] = raise Empty
       
   597     | first_lit _ (0 :: _) = raise Empty
       
   598     | first_lit pred (lit :: lits) =
       
   599         if pred lit then (lit, lits) else first_lit pred lits
       
   600 
       
   601   fun reason_cls_of vars lit =
       
   602     (case reason_of vars lit of
       
   603       Implied cls => cls
       
   604     | _ => raise Option)
       
   605 
       
    (* Conflict analysis.  Starting from the conflicting clause, walks the
       trail of the current decision level from the most recent assignment
       backwards, resolving with the reason clause of each marked literal,
       until exactly one literal of the current level remains (first unique
       implication point).  Returns (lit, lits, ps):
       - lit: the negation of that remaining trail literal;
       - lits: the collected literals from levels other than 0 and the
         current one;
       - ps: proof indices in resolution order (the conflicting clause
         first), interleaved with the Level0 proofs of the level-0 literals
         that were dropped. *)
    fun analyze conflicting_cls (level, trail, vars, _, _) =
      let
        (* i: current-level literals still pending resolution;
           lit: trail literal being resolved on (0 initially, matching nothing);
           (lits, p): clause and its proof index;
           ms: variables already seen;
           ls: collected lower-level literals;
           ps: proof indices in reverse order. *)
        fun back i lit (lits, p) trail ms ls ps =
          let
            (* level-0 literals are dropped; only their Level0 proofs are kept *)
            val (lits0, lits') = List.partition (equal 0 o level_of vars) lits
            (* skip the literal resolved on and variables already seen *)
            val lits1 = filter_out (ignore lit ms) lits'
            (* literals from earlier levels go into the learned clause *)
            val lits2 = filter_out (equal level o level_of vars) lits1
            val i' = length lits1 - length lits2 + i
            val ms' = fold mark lits1 ms
            val ls' = lits2 @ ls
            val ps' = level0_proofs_of vars lits0 @ (p :: ps)
            (* most recent marked literal; first_lit raises Empty rather than
               scan past the 0 separator below the current level *)
            val (lit', trail') = first_lit (marked ms') trail
          in 
            if i' = 1 then (neg lit', ls', rev ps')
            else back (i' - 1) lit' (reason_cls_of vars lit') trail' ms' ls' ps'
          end
      in back 0 0 conflicting_cls trail Inttab.empty [] [] end
       
   623 
       
   624   fun keep_clause (cls as (lits, _)) (level, trail, vars, clss, proofs) =
       
   625     let
       
   626       val vars' = fold (fn lit => update_var lit incr_rank) lits vars
       
   627       val clss' = fold (fn lit => Inttab.cons_list (neg lit, cls)) lits clss
       
   628     in (level, trail, vars', clss', proofs) end
       
   629 
       
   630   fun learn (cls as (lits, _)) = (length lits <= 2) ? keep_clause cls
       
   631 
       
   632   fun backjump _ (state as (_, [], _, _, _)) = state 
       
   633     | backjump i (level, 0 :: trail, vars, clss, proofs) =
       
   634         (level - 1, trail, vars, clss, proofs) |> (i > 1) ? backjump (i - 1)
       
   635     | backjump i (level, lit :: trail, vars, clss, proofs) =
       
   636         backjump i (level, trail, unassign lit vars, clss, proofs)
       
   637 
       
    (* Main search loop.  Propagates units; without a conflict it decides the
       next variable, or returns SATISFIABLE when decide finds none
       (assignment_of yields NONE for variables still unassigned).  On a
       conflict it:
       - analyzes the conflicting clause and records a proof step for the
         learned clause;
       - backjumps to the highest level among the learned clause's other
         literals (level 0 when there are none);
       - lets learn decide whether to store the clause;
       - asserts the clause's first literal with the clause as its reason,
         and continues propagating from that literal.
       UNSAT raised by prop at level 0 is not handled here. *)
    fun search units state =
      (case propagate units state of
        (NONE, state' as (_, _, vars, _, _)) =>
          (case decide state' of
            NONE => SatSolver.SATISFIABLE (assignment_of vars)
          | SOME (lit, state'') => search [lit] state'')
      | (SOME conflicting_cls, state' as (level, trail, vars, clss, proofs)) =>
          let 
            val (lit, lits, ps) = analyze conflicting_cls state'
            val (idx, proofs') = add_proof ps proofs
            val cls = (lit :: lits, idx)
          in
            (level, trail, vars, clss, proofs')
            (* pop the levels between the current level and the highest
               level among lits *)
            |> backjump (level - fold (Integer.max o level_of vars) lits 0)
            |> learn cls
            |> push lit cls
            |> search [lit]
          end)
       
   656 
       
   657   fun has_opposing_lits [] = false
       
   658     | has_opposing_lits (lit :: lits) =
       
   659         member (op =) lits (neg lit) orelse has_opposing_lits lits
       
   660 
       
   661   fun add_clause (cls as ([_], _)) (units, state) =
       
   662         let val (units', state') = prop cls (units, state)
       
   663         in (units', state') end
       
   664     | add_clause (cls as (lits, _)) (cx as (units, state)) =
       
   665         if has_opposing_lits lits then cx
       
   666         else (units, keep_clause cls state)
       
   667 
       
   668   fun mk_clause lits proofs =
       
   669     apfst (pair (distinct (op =) lits)) (add_proof [] proofs)
       
   670 
       
   671   fun solve litss =
       
   672     let
       
   673       val (clss, proofs) = fold_map mk_clause litss (0, Inttab.empty)
       
   674       val vars = fold (fold add_var) litss Inttab.empty
       
   675       val state = (0, [], vars, Inttab.empty, proofs)
       
   676     in uncurry search (fold add_clause clss ([], state)) end
       
   677     handle UNSAT (conflicting_cls, (_, _, vars, _, proofs)) =>
       
   678       let val (idx, (_, ptab)) = mk_proof vars conflicting_cls proofs
       
   679       in SatSolver.UNSATISFIABLE (SOME (ptab, idx)) end
       
   680 
       
   681   fun variable_of (Prop_Logic.BoolVar 0) = error "bad propositional variable"
       
   682     | variable_of (Prop_Logic.BoolVar i) = i
       
   683     | variable_of _ = error "expected formula in CNF"
       
   684   fun literal_of (Prop_Logic.Not fm) = neg (variable_of fm)
       
   685     | literal_of fm = variable_of fm
       
   686   fun clause_of (Prop_Logic.Or (fm1, fm2)) = clause_of fm1 @ clause_of fm2
       
   687     | clause_of fm = [literal_of fm]
       
   688   fun clauses_of (Prop_Logic.And (fm1, fm2)) = clauses_of fm1 @ clauses_of fm2
       
   689     | clauses_of Prop_Logic.True = [[1, ~1]]
       
   690     | clauses_of Prop_Logic.False = [[1], [~1]]
       
   691     | clauses_of fm = [clause_of fm]
       
   692 
       
   693   fun dpll_solver fm =
       
   694     let val fm' = if Prop_Logic.is_cnf fm then fm else Prop_Logic.defcnf fm
       
   695     in solve (clauses_of fm') end
       
   696 in
       
   697   SatSolver.add_solver ("dpll_p", dpll_solver)
   503 end;
   698 end;
   504 
   699 
   505 (* ------------------------------------------------------------------------- *)
   700 (* ------------------------------------------------------------------------- *)
   506 (* Internal SAT solver, available as 'SatSolver.invoke_solver "auto"': uses  *)
   701 (* Internal SAT solver, available as 'SatSolver.invoke_solver "auto"': uses  *)
   507 (* the last installed solver (other than "auto" itself) that does not raise  *)
   702 (* the last installed solver (other than "auto" itself) that does not raise  *)