src/HOL/Library/Sum_Of_Squares/sum_of_squares.ML
changeset 32645 1cc5b24f5a01
parent 32332 bc5cec7b2be6
child 32646 962b4354ed90
--- a/src/HOL/Library/Sum_Of_Squares/sum_of_squares.ML	Tue Sep 22 20:25:31 2009 +0200
+++ b/src/HOL/Library/Sum_Of_Squares/sum_of_squares.ML	Mon Sep 21 15:05:26 2009 +0200
@@ -8,7 +8,12 @@
 signature SOS =
 sig
 
-  val sos_tac : (string -> string) -> Proof.context -> int -> Tactical.tactic
+  datatype proof_method =
+    Certificate of RealArith.pss_tree
+  | Prover of (string -> string)
+
+  val sos_tac : (RealArith.pss_tree -> unit) ->
+    proof_method -> Proof.context -> int -> Tactical.tactic
 
   val debugging : bool ref;
   
@@ -18,6 +23,8 @@
 structure Sos : SOS = 
 struct
 
+open FuncUtil;
+
 val rat_0 = Rat.zero;
 val rat_1 = Rat.one;
 val rat_2 = Rat.two;
@@ -59,6 +66,10 @@
 
 exception Failure of string;
 
+datatype proof_method =
+    Certificate of RealArith.pss_tree
+  | Prover of (string -> string)
+
 (* Turn a rational into a decimal string with d sig digits.                  *)
 
 local
@@ -93,23 +104,11 @@
 
 (* The main types.                                                           *)
 
-fun strict_ord ord (x,y) = case ord (x,y) of LESS => LESS | _ => GREATER
-
-structure Intpairfunc = FuncFun(type key = int*int val ord = prod_ord int_ord int_ord);
-
 type vector = int* Rat.rat Intfunc.T;
 
 type matrix = (int*int)*(Rat.rat Intpairfunc.T);
 
-type monomial = int Ctermfunc.T;
-
-val cterm_ord = (fn (s,t) => TermOrd.fast_term_ord(term_of s, term_of t))
- fun monomial_ord (m1,m2) = list_ord (prod_ord cterm_ord int_ord) (Ctermfunc.graph m1, Ctermfunc.graph m2)
-structure Monomialfunc = FuncFun(type key = monomial val ord = monomial_ord)
-
-type poly = Rat.rat Monomialfunc.T;
-
- fun iszero (k,r) = r =/ rat_0;
+fun iszero (k,r) = r =/ rat_0;
 
 fun fold_rev2 f l1 l2 b =
   case (l1,l2) of
@@ -346,10 +345,7 @@
   sort humanorder_varpow (Ctermfunc.graph m2))
 end;
 
-fun fold1 f l =  case l of
-   []     => error "fold1"
- | [x]    => x
- | (h::t) => f h (fold1 f t);
+fun fold1 f = foldr1 (uncurry f) 
 
 (* Conversions to strings.                                                   *)
 
@@ -404,7 +400,7 @@
  else if c =/ rat_1 then string_of_monomial m
  else Rat.string_of_rat c ^ "*" ^ string_of_monomial m;;
 
-fun string_of_poly (p:poly) =
+fun string_of_poly p =
  if Monomialfunc.is_undefined p then "<<0>>" else
  let 
   val cms = sort (fn ((m1,_),(m2,_)) => humanorder_monomial m1  m2) (Monomialfunc.graph p)
@@ -481,7 +477,6 @@
  in fold1 (fn x => fn y => x ^ " " ^ y) strs ^ "\n"
  end;
 
-fun increasing f ord (x,y) = ord (f x, f y);
 fun triple_int_ord ((a,b,c),(a',b',c')) = 
  prod_ord int_ord (prod_ord int_ord int_ord) 
     ((a,(b,c)),(a',(b',c')));
@@ -1080,11 +1075,6 @@
   fun tryfind f = tryfind_with "tryfind" f
 end
 
-(*
-fun tryfind f [] = error "tryfind"
-  | tryfind f (x::xs) = (f x handle ERROR _ => tryfind f xs);
-*)
-
 (* Positiv- and Nullstellensatz. Flag "linf" forces a linear representation. *)
 
  
@@ -1210,58 +1200,14 @@
 fun deepen f n = 
   (writeln ("Searching with depth limit " ^ string_of_int n) ; (f n handle Failure s => (writeln ("failed with message: " ^ s) ; deepen f (n+1))))
 
-(* The ordering so we can create canonical HOL polynomials.                  *)
 
-fun dest_monomial mon = sort (increasing fst cterm_ord) (Ctermfunc.graph mon);
-
-fun monomial_order (m1,m2) =
- if Ctermfunc.is_undefined m2 then LESS 
- else if Ctermfunc.is_undefined m1 then GREATER 
- else
-  let val mon1 = dest_monomial m1 
-      val mon2 = dest_monomial m2
-      val deg1 = fold (curry op + o snd) mon1 0
-      val deg2 = fold (curry op + o snd) mon2 0 
-  in if deg1 < deg2 then GREATER else if deg1 > deg2 then LESS
-     else list_ord (prod_ord cterm_ord int_ord) (mon1,mon2)
-  end;
-
-fun dest_poly p =
-  map (fn (m,c) => (c,dest_monomial m))
-      (sort (prod_ord monomial_order (K EQUAL)) (Monomialfunc.graph p));
-
-(* Map back polynomials and their composites to HOL.                         *)
+(* Map back polynomials and their composites to a positivstellensatz.        *)
 
 local
  open Thm Numeral RealArith
 in
 
-fun cterm_of_varpow x k = if k = 1 then x else capply (capply @{cterm "op ^ :: real => _"} x) 
-  (mk_cnumber @{ctyp nat} k)
-
-fun cterm_of_monomial m = 
- if Ctermfunc.is_undefined m then @{cterm "1::real"} 
- else 
-  let 
-   val m' = dest_monomial m
-   val vps = fold_rev (fn (x,k) => cons (cterm_of_varpow x k)) m' [] 
-  in fold1 (fn s => fn t => capply (capply @{cterm "op * :: real => _"} s) t) vps
-  end
-
-fun cterm_of_cmonomial (m,c) = if Ctermfunc.is_undefined m then cterm_of_rat c
-    else if c = Rat.one then cterm_of_monomial m
-    else capply (capply @{cterm "op *::real => _"} (cterm_of_rat c)) (cterm_of_monomial m);
-
-fun cterm_of_poly p = 
- if Monomialfunc.is_undefined p then @{cterm "0::real"} 
- else
-  let 
-   val cms = map cterm_of_cmonomial
-     (sort (prod_ord monomial_order (K EQUAL)) (Monomialfunc.graph p))
-  in fold1 (fn t1 => fn t2 => capply(capply @{cterm "op + :: real => _"} t1) t2) cms
-  end;
-
-fun cterm_of_sqterm (c,p) = Product(Rational_lt c,Square(cterm_of_poly p));
+fun cterm_of_sqterm (c,p) = Product(Rational_lt c,Square p);
 
 fun cterm_of_sos (pr,sqs) = if null sqs then pr
   else Product(pr,fold1 (fn a => fn b => Sum(a,b)) (map cterm_of_sqterm sqs));
@@ -1275,14 +1221,14 @@
   fun simple_cterm_ord t u = TermOrd.fast_term_ord (term_of t, term_of u) = LESS
 in
   (* FIXME: Replace tryfind by get_first !! *)
-fun real_nonlinear_prover prover ctxt =
+fun real_nonlinear_prover proof_method ctxt =
  let 
   val {add,mul,neg,pow,sub,main} =  Normalizer.semiring_normalizers_ord_wrapper ctxt
       (valOf (NormalizerData.match ctxt @{cterm "(0::real) + 1"})) 
      simple_cterm_ord
   val (real_poly_add_conv,real_poly_mul_conv,real_poly_neg_conv,
        real_poly_pow_conv,real_poly_sub_conv,real_poly_conv) = (add,mul,neg,pow,sub,main)
-  fun mainf  translator (eqs,les,lts) = 
+  fun mainf cert_choice translator (eqs,les,lts) = 
   let 
    val eq0 = map (poly_of_term o dest_arg1 o concl) eqs
    val le0 = map (poly_of_term o dest_arg o concl) les
@@ -1303,33 +1249,49 @@
                      else raise Failure "trivial_axiom: Not a trivial axiom"
      | _ => error "trivial_axiom: Not a trivial axiom"
    in 
-  ((let val th = tryfind trivial_axiom (keq @ klep @ kltp)
-   in fconv_rule (arg_conv (arg1_conv real_poly_conv) then_conv field_comp_conv) th end)
-   handle Failure _ => (
-    let 
-     val pol = fold_rev poly_mul (map fst ltp) (poly_const Rat.one)
-     val leq = lep @ ltp
-     fun tryall d =
-      let val e = multidegree pol
-          val k = if e = 0 then 0 else d div e
-          val eq' = map fst eq 
-      in tryfind (fn i => (d,i,real_positivnullstellensatz_general prover false d eq' leq
-                            (poly_neg(poly_pow pol i))))
-              (0 upto k)
-      end
-    val (d,i,(cert_ideal,cert_cone)) = deepen tryall 0
-    val proofs_ideal =
-      map2 (fn q => fn (p,ax) => Eqmul(cterm_of_poly q,ax)) cert_ideal eq
-    val proofs_cone = map cterm_of_sos cert_cone
-    val proof_ne = if null ltp then Rational_lt Rat.one else
-      let val p = fold1 (fn s => fn t => Product(s,t)) (map snd ltp) 
-      in  funpow i (fn q => Product(p,q)) (Rational_lt Rat.one)
-      end
-    val proof = fold1 (fn s => fn t => Sum(s,t))
-                           (proof_ne :: proofs_ideal @ proofs_cone) 
-    in writeln "Translating proof certificate to HOL";
-       translator (eqs,les,lts) proof
-    end))
+  (let val th = tryfind trivial_axiom (keq @ klep @ kltp)
+   in
+    (fconv_rule (arg_conv (arg1_conv real_poly_conv) then_conv field_comp_conv) th, Trivial)
+   end)
+   handle Failure _ => 
+     (let val proof =
+       (case proof_method of Certificate certs =>
+         (* choose certificate *)
+         let
+           fun chose_cert [] (Cert c) = c
+             | chose_cert (Left::s) (Branch (l, _)) = chose_cert s l
+             | chose_cert (Right::s) (Branch (_, r)) = chose_cert s r
+             | chose_cert _ _ = error "certificate tree in invalid form"
+         in
+           chose_cert cert_choice certs
+         end
+       | Prover prover =>
+         (* call prover *)
+         let 
+          val pol = fold_rev poly_mul (map fst ltp) (poly_const Rat.one)
+          val leq = lep @ ltp
+          fun tryall d =
+           let val e = multidegree pol
+               val k = if e = 0 then 0 else d div e
+               val eq' = map fst eq 
+           in tryfind (fn i => (d,i,real_positivnullstellensatz_general prover false d eq' leq
+                                 (poly_neg(poly_pow pol i))))
+                   (0 upto k)
+           end
+         val (d,i,(cert_ideal,cert_cone)) = deepen tryall 0
+         val proofs_ideal =
+           map2 (fn q => fn (p,ax) => Eqmul(q,ax)) cert_ideal eq
+         val proofs_cone = map cterm_of_sos cert_cone
+         val proof_ne = if null ltp then Rational_lt Rat.one else
+           let val p = fold1 (fn s => fn t => Product(s,t)) (map snd ltp) 
+           in  funpow i (fn q => Product(p,q)) (Rational_lt Rat.one)
+           end
+         in 
+           fold1 (fn s => fn t => Sum(s,t)) (proof_ne :: proofs_ideal @ proofs_cone) 
+         end)
+     in
+        (translator (eqs,les,lts) proof, Cert proof)
+     end)
    end
  in mainf end
 end
@@ -1396,7 +1358,7 @@
          orelse g aconvc @{cterm "op < :: real => _"} 
        then arg_conv cv ct else arg1_conv cv ct
     end
-  fun mainf translator =
+  fun mainf cert_choice translator =
    let 
     fun substfirst(eqs,les,lts) =
       ((let 
@@ -1407,7 +1369,7 @@
                                    aconvc @{cterm "0::real"}) (map modify eqs),
                                    map modify les,map modify lts)
        end)
-       handle Failure  _ => real_nonlinear_prover prover ctxt translator (rev eqs, rev les, rev lts))
+       handle Failure  _ => real_nonlinear_prover prover ctxt cert_choice translator (rev eqs, rev les, rev lts))
     in substfirst
    end
 
@@ -1417,7 +1379,8 @@
 
 (* Overall function. *)
 
-fun real_sos prover ctxt t = gen_prover_real_arith ctxt (real_nonlinear_subst_prover prover ctxt) t;
+fun real_sos prover ctxt =
+  gen_prover_real_arith ctxt (real_nonlinear_subst_prover prover ctxt)
 end;
 
 (* A tactic *)
@@ -1429,8 +1392,6 @@
    end
 | _ => ([],ct)
 
-fun core_sos_conv prover ctxt t = Drule.arg_cong_rule @{cterm Trueprop} (real_sos prover ctxt (Thm.dest_arg t) RS @{thm Eq_TrueI})
-
 val known_sos_constants = 
   [@{term "op ==>"}, @{term "Trueprop"}, 
    @{term "op -->"}, @{term "op &"}, @{term "op |"}, 
@@ -1458,17 +1419,19 @@
   val _ = if exists (fn ((_,T)) => not (T = @{typ "real"})) fs 
           then error "SOS: not sos. Variables with type not real" else ()
   val vs = Term.add_vars t []
-  val _ = if exists (fn ((_,T)) => not (T = @{typ "real"})) fs 
+  val _ = if exists (fn ((_,T)) => not (T = @{typ "real"})) vs 
           then error "SOS: not sos. Variables with type not real" else ()
   val ukcs = subtract (fn (t,p) => Const p aconv t) kcts (Term.add_consts t [])
   val _ = if  null ukcs then () 
               else error ("SOSO: Unknown constants in Subgoal:" ^ commas (map fst ukcs))
 in () end
 
-fun core_sos_tac prover ctxt = CSUBGOAL (fn (ct, i) => 
+fun core_sos_tac print_certs prover ctxt = CSUBGOAL (fn (ct, i) => 
   let val _ = check_sos known_sos_constants ct
       val (avs, p) = strip_all ct
-      val th = standard (fold_rev forall_intr avs (real_sos prover ctxt (Thm.dest_arg p)))
+      val (ths, certificates) = real_sos prover ctxt (Thm.dest_arg p)
+      val th = standard (fold_rev forall_intr avs ths)
+      val _ = print_certs certificates
   in rtac th i end);
 
 fun default_SOME f NONE v = SOME v
@@ -1506,7 +1469,7 @@
 
 fun elim_denom_tac ctxt i = REPEAT (elim_one_denom_tac ctxt i);
 
-fun sos_tac prover ctxt = ObjectLogic.full_atomize_tac THEN' elim_denom_tac ctxt THEN' core_sos_tac prover ctxt
+fun sos_tac print_certs prover ctxt = ObjectLogic.full_atomize_tac THEN' elim_denom_tac ctxt THEN' core_sos_tac print_certs prover ctxt
 
 
 end;