src/HOL/Tools/sat_funcs.ML
changeset 20278 28be10991666
parent 20170 6ff853f82d73
child 20371 a0f8e89d369d
--- a/src/HOL/Tools/sat_funcs.ML	Tue Aug 01 15:28:55 2006 +0200
+++ b/src/HOL/Tools/sat_funcs.ML	Wed Aug 02 00:57:41 2006 +0200
@@ -76,6 +76,23 @@
 		Thm.instantiate ([], [(cterm Q, cterm False)]) case_split_thm
 	end;
 
+(* Thm.cterm *)
+val cP = cterm_of (theory_of_thm resolution_thm) (Var (("P", 0), HOLogic.boolT));
+
+(* ------------------------------------------------------------------------- *)
+(* CLAUSE: during proof reconstruction, three kinds of clauses are           *)
+(*      distinguished:                                                       *)
+(*      1. NO_CLAUSE: clause not proved (yet)                                *)
+(*      2. ORIG_CLAUSE: a clause as it occurs in the original problem        *)
+(*      3. RAW_CLAUSE: a raw clause, with additional precomputed information *)
+(*         (a mapping from int's to its literals) for faster proof           *)
+(*         reconstruction                                                    *)
+(* ------------------------------------------------------------------------- *)
+
+datatype CLAUSE = NO_CLAUSE
+                | ORIG_CLAUSE of Thm.thm
+                | RAW_CLAUSE of Thm.thm * Thm.cterm Inttab.table;
+
 (* ------------------------------------------------------------------------- *)
 (* resolve_raw_clauses: given a non-empty list of raw clauses, we fold       *)
 (*      resolution over the list (starting with its head), i.e. with two raw *)
@@ -92,78 +109,47 @@
 (*        [| ?P ==> False; ~?P ==> False |] ==> False                        *)
 (*      to produce                                                           *)
 (*        x1; ...; xn; y1; ...; ym |- False                                  *)
+(*      Each clause is accompanied with a table mapping integers (positive   *)
+(*      for positive literals, negative for negative literals, and the same  *)
+(*      absolute value for dual literals) to the actual literals as cterms.  *)
 (* ------------------------------------------------------------------------- *)
 
-(* Thm.thm list -> Thm.thm *)
+(* (Thm.thm * Thm.cterm Inttab.table) list -> Thm.thm * Thm.cterm Inttab.table *)
 
 fun resolve_raw_clauses [] =
 	raise THM ("Proof reconstruction failed (empty list of resolvents)!", 0, [])
   | resolve_raw_clauses (c::cs) =
 	let
-		fun dual (Const ("Not", _) $ x) = x
-		  | dual x                      = HOLogic.Not $ x
-
-		fun is_neg (Const ("Not", _) $ _) = true
-		  | is_neg _                      = false
-
-		(* see the comments on the term order below for why this implementation is sound *)
-		(* (Term.term * Term.term -> order) -> Thm.cterm list -> Term.term -> Thm.cterm option *)
-		fun member' _   []      _ = NONE
-		  | member' ord (y::ys) x = (case term_of y of  (* "un-certifying" y is faster than certifying x *)
-			  Const ("Trueprop", _) $ y' =>
-				(* compare the order *)
-				(case ord (x, y') of
-				  LESS    => NONE
-				| EQUAL   => SOME y
-				| GREATER => member' ord ys x)
-			| _                         =>
-				(* no need to continue in this case *)
-				NONE)
+		(* find out which two hyps are used in the resolution *)
+		local exception RESULT of int * Thm.cterm * Thm.cterm in
+			(* Thm.cterm Inttab.table -> Thm.cterm Inttab.table -> int * Thm.cterm * Thm.cterm *)
+			fun find_res_hyps hyp1_table hyp2_table = (
+				Inttab.fold (fn (i, hyp1) => fn () =>
+					case Inttab.lookup hyp2_table (~i) of
+					  SOME hyp2 => raise RESULT (i, hyp1, hyp2)
+					| NONE      => ()) hyp1_table ();
+				raise THM ("Proof reconstruction failed (no literal for resolution)!", 0, [])
+			) handle RESULT x => x
+		end
 
-		(* find out which two hyps are used in the resolution *)
-		(* Thm.cterm list -> Thm.cterm list -> Thm.cterm * Thm.cterm *)
-		fun res_hyps [] _ =
-			raise THM ("Proof reconstruction failed (no literal for resolution)!", 0, [])
-		  | res_hyps _ [] =
-			raise THM ("Proof reconstruction failed (no literal for resolution)!", 0, [])
-		  | res_hyps (x :: xs) ys =
-			(case term_of x of
-			  Const ("Trueprop", _) $ lit =>
-				(* hyps are implemented as ordered list in the kernel, and *)
-				(* stripping 'Trueprop' should not change the order        *)
-				(case member' Term.fast_term_ord ys (dual lit) of
-				  SOME y => (x, y)
-				| NONE   => res_hyps xs ys)
-			| _ =>
-				(* hyps are implemented as ordered list in the kernel, all hyps are of *)
-				(* the form 'Trueprop $ lit' or 'implies $ (negated clause) $ False',  *)
-				(* and the former are LESS than the latter according to the order --   *)
-				(* therefore there is no need to continue the search via               *)
-				(* 'res_hyps xs ys' here                                               *)
-				raise THM ("Proof reconstruction failed (no literal for resolution)!", 0, []))
-
-		(* Thm.thm -> Thm.thm -> Thm.thm *)
-		fun resolution c1 c2 =
+		(* Thm.thm * Thm.cterm Inttab.table -> Thm.thm * Thm.cterm Inttab.table -> Thm.thm * Thm.cterm Inttab.table *)
+		fun resolution (c1, hyp1_table) (c2, hyp2_table) =
 		let
 			val _ = if !trace_sat then
 					tracing ("Resolving clause: " ^ string_of_thm c1 ^ " (hyps: " ^ space_implode ", " (map (Sign.string_of_term (theory_of_thm c1)) (#hyps (rep_thm c1)))
 						^ ")\nwith clause: " ^ string_of_thm c2 ^ " (hyps: " ^ space_implode ", " (map (Sign.string_of_term (theory_of_thm c2)) (#hyps (rep_thm c2))) ^ ")")
 				else ()
 
-			val hyps1     = (#hyps o crep_thm) c1
-			val hyps2     = (#hyps o crep_thm) c2
-
-			val (l1, l2)  = res_hyps hyps1 hyps2  (* the two literals used for resolution *)
-			val l1_is_neg = (is_neg o HOLogic.dest_Trueprop o term_of) l1
+			(* the two literals used for resolution *)
+			val (i1, hyp1, hyp2) = find_res_hyps hyp1_table hyp2_table
+			val hyp1_is_neg      = i1 < 0
 
-			val c1'       = Thm.implies_intr l1 c1  (* Gamma1 |- l1 ==> False *)
-			val c2'       = Thm.implies_intr l2 c2  (* Gamma2 |- l2 ==> False *)
+			val c1' = Thm.implies_intr hyp1 c1  (* Gamma1 |- hyp1 ==> False *)
+			val c2' = Thm.implies_intr hyp2 c2  (* Gamma2 |- hyp2 ==> False *)
 
-			val res_thm   =  (* |- (lit ==> False) ==> (~lit ==> False) ==> False *)
+			val res_thm =  (* |- (lit ==> False) ==> (~lit ==> False) ==> False *)
 				let
-					val thy  = theory_of_thm (if l1_is_neg then c2' else c1')
-					val cP   = cterm_of thy (Var (("P", 0), HOLogic.boolT))
-					val cLit = snd (Thm.dest_comb (if l1_is_neg then l2 else l1))  (* strip Trueprop *)
+					val cLit = snd (Thm.dest_comb (if hyp1_is_neg then hyp2 else hyp1))  (* strip Trueprop *)
 				in
 					Thm.instantiate ([], [(cP, cLit)]) resolution_thm
 				end
@@ -172,14 +158,19 @@
 					tracing ("Resolution theorem: " ^ string_of_thm res_thm)
 				else ()
 
-			val c_new     = Thm.implies_elim (Thm.implies_elim res_thm (if l1_is_neg then c2' else c1')) (if l1_is_neg then c1' else c2')  (* Gamma1, Gamma2 |- False *)
+			val c_new = Thm.implies_elim (Thm.implies_elim res_thm (if hyp1_is_neg then c2' else c1')) (if hyp1_is_neg then c1' else c2')  (* Gamma1, Gamma2 |- False *)
+
+			(* since the mapping from integers to literals should be injective *)
+			(* (over different clauses), 'K true' here should be equivalent to *)
+			(* 'op=' (but faster)                                              *)
+			val hypnew_table = Inttab.merge (K true) (Inttab.delete i1 hyp1_table, Inttab.delete (~i1) hyp2_table)
 
 			val _ = if !trace_sat then
 					tracing ("Resulting clause: " ^ string_of_thm c_new ^ " (hyps: " ^ space_implode ", " (map (Sign.string_of_term (theory_of_thm c_new)) (#hyps (rep_thm c_new))) ^ ")")
 				else ()
 			val _ = inc counter
 		in
-			c_new
+			(c_new, hypnew_table)
 		end
 	in
 		fold resolution cs c
@@ -191,36 +182,67 @@
 (*      'clauses' array with derived clauses, and returns the derived clause *)
 (*      at index 'empty_id' (which should just be "False" if proof           *)
 (*      reconstruction was successful, with the used clauses as hyps).       *)
+(*      'atom_table' must contain an injective mapping from all atoms that   *)
+(*      occur (as part of a literal) in 'clauses' to positive integers.      *)
 (* ------------------------------------------------------------------------- *)
 
-(* Thm.thm option Array.array -> SatSolver.proof -> Thm.thm *)
+(* int Termtab.table -> CLAUSE Array.array -> SatSolver.proof -> Thm.thm *)
 
-fun replay_proof clauses (clause_table, empty_id) =
+fun replay_proof atom_table clauses (clause_table, empty_id) =
 let
-	(* int -> Thm.thm *)
+	(* Thm.cterm -> int option *)
+	fun index_of_literal chyp = (
+		case (HOLogic.dest_Trueprop o term_of) chyp of
+		  (Const ("Not", _) $ atom) =>
+			SOME (~(valOf (Termtab.lookup atom_table atom)))
+		| atom =>
+			SOME (valOf (Termtab.lookup atom_table atom))
+	) handle TERM _ => NONE;  (* 'chyp' is not a literal *)
+
+	(* int -> Thm.thm * Thm.cterm Inttab.table *)
 	fun prove_clause id =
 		case Array.sub (clauses, id) of
-		  SOME thm =>
-			thm
-		| NONE     =>
+		  RAW_CLAUSE clause =>
+			clause
+		| ORIG_CLAUSE thm =>
+			(* convert the original clause *)
 			let
-				val _   = if !trace_sat then tracing ("Proving clause #" ^ string_of_int id ^ " ...") else ()
-				val ids = valOf (Inttab.lookup clause_table id)
-				val thm = resolve_raw_clauses (map prove_clause ids)
-				val _   = Array.update (clauses, id, SOME thm)
-				val _   = if !trace_sat then tracing ("Replay chain successful; clause stored at #" ^ string_of_int id) else ()
+				val _         = if !trace_sat then tracing ("Using original clause #" ^ string_of_int id) else ()
+				val raw       = cnf.clause2raw_thm thm
+				val lit_table = fold (fn chyp => fn lit_table => (case index_of_literal chyp of
+					  SOME i => Inttab.update_new (i, chyp) lit_table
+					| NONE   => lit_table)) (#hyps (Thm.crep_thm raw)) Inttab.empty
+				val clause    = (raw, lit_table)
+				val _         = Array.update (clauses, id, RAW_CLAUSE clause)
 			in
-				thm
+				clause
+			end
+		| NO_CLAUSE =>
+			(* prove the clause, using information from 'clause_table' *)
+			let
+				val _      = if !trace_sat then tracing ("Proving clause #" ^ string_of_int id ^ " ...") else ()
+				val ids    = valOf (Inttab.lookup clause_table id)
+				val clause = resolve_raw_clauses (map prove_clause ids)
+				val _      = Array.update (clauses, id, RAW_CLAUSE clause)
+				val _      = if !trace_sat then tracing ("Replay chain successful; clause stored at #" ^ string_of_int id) else ()
+			in
+				clause
 			end
 
 	val _            = counter := 0
-	val empty_clause = prove_clause empty_id
+	val empty_clause = fst (prove_clause empty_id)
 	val _            = if !trace_sat then tracing ("Proof reconstruction successful; " ^ string_of_int (!counter) ^ " resolution step(s) total.") else ()
 in
 	empty_clause
 end;
 
+(* ------------------------------------------------------------------------- *)
+(* string_of_prop_formula: return a human-readable string representation of  *)
+(*      a 'prop_formula' (just for tracing)                                  *)
+(* ------------------------------------------------------------------------- *)
+
 (* PropLogic.prop_formula -> string *)
+
 fun string_of_prop_formula PropLogic.True             = "True"
   | string_of_prop_formula PropLogic.False            = "False"
   | string_of_prop_formula (PropLogic.BoolVar i)      = "x" ^ string_of_int i
@@ -286,18 +308,12 @@
 			make_quick_and_dirty_thm ()
 		else
 			let
-				(* initialize the clause array with the given clauses, *)
-				(* but converted to raw clause format                  *)
+				(* initialize the clause array with the given clauses *)
 				val max_idx     = valOf (Inttab.max_key clause_table)
-				val clause_arr  = Array.array (max_idx + 1, NONE)
-				val raw_clauses = map cnf.clause2raw_thm non_triv_clauses
-				(* Every raw clause has only its literals and itself as hyp, and hyps are *)
-				(* accumulated during resolution steps.  Experimental results indicate    *)
-				(* that it is NOT faster to weaken all raw_clauses to contain every       *)
-				(* clause in the hyps beforehand.                                         *)
-				val _           = fold (fn thm => fn idx => (Array.update (clause_arr, idx, SOME thm); idx+1)) raw_clauses 0
+				val clause_arr  = Array.array (max_idx + 1, NO_CLAUSE)
+				val _           = fold (fn thm => fn idx => (Array.update (clause_arr, idx, ORIG_CLAUSE thm); idx+1)) non_triv_clauses 0
 				(* replay the proof to derive the empty clause *)
-				val FalseThm    = replay_proof clause_arr (clause_table, empty_id)
+				val FalseThm    = replay_proof atom_table clause_arr (clause_table, empty_id)
 			in
 				(* convert the hyps back to the original format *)
 				cnf.rawhyps2clausehyps_thm FalseThm
@@ -353,8 +369,9 @@
 
 (* int -> Tactical.tactic *)
 
-fun pre_cnf_tac i = rtac ccontr i THEN ObjectLogic.atomize_tac i THEN
-                      PRIMITIVE (Drule.fconv_rule (Drule.goals_conv (equal i) (Drule.beta_eta_conversion)));
+fun pre_cnf_tac i =
+	rtac ccontr i THEN ObjectLogic.atomize_tac i THEN
+		PRIMITIVE (Drule.fconv_rule (Drule.goals_conv (equal i) (Drule.beta_eta_conversion)));
 
 (* ------------------------------------------------------------------------- *)
 (* cnfsat_tac: checks if the empty clause "False" occurs among the premises; *)
@@ -365,7 +382,8 @@
 
 (* int -> Tactical.tactic *)
 
-fun cnfsat_tac i = (etac FalseE i) ORELSE (REPEAT_DETERM (etac conjE i) THEN rawsat_tac i);
+fun cnfsat_tac i =
+	(etac FalseE i) ORELSE (REPEAT_DETERM (etac conjE i) THEN rawsat_tac i);
 
 (* ------------------------------------------------------------------------- *)
 (* cnfxsat_tac: checks if the empty clause "False" occurs among the          *)
@@ -376,7 +394,9 @@
 
 (* int -> Tactical.tactic *)
 
-fun cnfxsat_tac i = (etac FalseE i) ORELSE (REPEAT_DETERM (etac conjE i ORELSE etac exE i) THEN rawsat_tac i);
+fun cnfxsat_tac i =
+	(etac FalseE i) ORELSE
+		(REPEAT_DETERM (etac conjE i ORELSE etac exE i) THEN rawsat_tac i);
 
 (* ------------------------------------------------------------------------- *)
 (* sat_tac: tactic for calling an external SAT solver, taking as input an    *)
@@ -386,7 +406,8 @@
 
 (* int -> Tactical.tactic *)
 
-fun sat_tac i = pre_cnf_tac i THEN cnf.cnf_rewrite_tac i THEN cnfsat_tac i;
+fun sat_tac i =
+	pre_cnf_tac i THEN cnf.cnf_rewrite_tac i THEN cnfsat_tac i;
 
 (* ------------------------------------------------------------------------- *)
 (* satx_tac: tactic for calling an external SAT solver, taking as input an   *)
@@ -396,6 +417,7 @@
 
 (* int -> Tactical.tactic *)
 
-fun satx_tac i = pre_cnf_tac i THEN cnf.cnfx_rewrite_tac i THEN cnfxsat_tac i;
+fun satx_tac i =
+	pre_cnf_tac i THEN cnf.cnfx_rewrite_tac i THEN cnfxsat_tac i;
 
 end;  (* of structure SATFunc *)