src/HOL/Multivariate_Analysis/normarith.ML
author wenzelm
Fri May 13 22:55:00 2011 +0200 (2011-05-13)
changeset 42793 88bee9f6eec7
parent 42361 23f352990944
child 43333 2bdec7f430d3
permissions -rw-r--r--
proper Proof.context for classical tactics;
reduced claset to snapshot of classical context;
discontinued clasimpset;
     1 (*  Title:      HOL/Multivariate_Analysis/normarith.ML
     2     Author:     Amine Chaieb, University of Cambridge
     3 
     4 Simple decision procedure for linear problems in Euclidean space.
     5 *)
     6 
     7 signature NORM_ARITH =
     8 sig
     9  val norm_arith : Proof.context -> conv
    10  val norm_arith_tac : Proof.context -> int -> tactic
    11 end
    12 
    13 structure NormArith : NORM_ARITH =
    14 struct
    15 
    16  open Conv;
    17  val bool_eq = op = : bool *bool -> bool
    18   fun dest_ratconst t = case term_of t of
    19    Const(@{const_name divide}, _)$a$b => Rat.rat_of_quotient(HOLogic.dest_number a |> snd, HOLogic.dest_number b |> snd)
    20  | Const(@{const_name inverse}, _)$a => Rat.rat_of_quotient(1, HOLogic.dest_number a |> snd)
    21  | _ => Rat.rat_of_int (HOLogic.dest_number (term_of t) |> snd)
    22  fun is_ratconst t = can dest_ratconst t
    23  fun augment_norm b t acc = case term_of t of
    24      Const(@{const_name norm}, _) $ _ => insert (eq_pair bool_eq (op aconvc)) (b,Thm.dest_arg t) acc
    25    | _ => acc
    26  fun find_normedterms t acc = case term_of t of
    27     @{term "op + :: real => _"}$_$_ =>
    28             find_normedterms (Thm.dest_arg1 t) (find_normedterms (Thm.dest_arg t) acc)
    29       | @{term "op * :: real => _"}$_$n =>
    30             if not (is_ratconst (Thm.dest_arg1 t)) then acc else
    31             augment_norm (dest_ratconst (Thm.dest_arg1 t) >=/ Rat.zero)
    32                       (Thm.dest_arg t) acc
    33       | _ => augment_norm true t acc
    34 
    35  val cterm_lincomb_neg = FuncUtil.Ctermfunc.map (K Rat.neg)
    36  fun cterm_lincomb_cmul c t =
    37     if c =/ Rat.zero then FuncUtil.Ctermfunc.empty else FuncUtil.Ctermfunc.map (fn _ => fn x => x */ c) t
    38  fun cterm_lincomb_add l r = FuncUtil.Ctermfunc.combine (curry op +/) (fn x => x =/ Rat.zero) l r
    39  fun cterm_lincomb_sub l r = cterm_lincomb_add l (cterm_lincomb_neg r)
    40  fun cterm_lincomb_eq l r = FuncUtil.Ctermfunc.is_empty (cterm_lincomb_sub l r)
    41 
    42  val int_lincomb_neg = FuncUtil.Intfunc.map (K Rat.neg)
    43  fun int_lincomb_cmul c t =
    44     if c =/ Rat.zero then FuncUtil.Intfunc.empty else FuncUtil.Intfunc.map (fn _ => fn x => x */ c) t
    45  fun int_lincomb_add l r = FuncUtil.Intfunc.combine (curry op +/) (fn x => x =/ Rat.zero) l r
    46  fun int_lincomb_sub l r = int_lincomb_add l (int_lincomb_neg r)
    47  fun int_lincomb_eq l r = FuncUtil.Intfunc.is_empty (int_lincomb_sub l r)
    48 
    49 fun vector_lincomb t = case term_of t of
    50    Const(@{const_name plus}, _) $ _ $ _ =>
    51     cterm_lincomb_add (vector_lincomb (Thm.dest_arg1 t)) (vector_lincomb (Thm.dest_arg t))
    52  | Const(@{const_name minus}, _) $ _ $ _ =>
    53     cterm_lincomb_sub (vector_lincomb (Thm.dest_arg1 t)) (vector_lincomb (Thm.dest_arg t))
    54  | Const(@{const_name scaleR}, _)$_$_ =>
    55     cterm_lincomb_cmul (dest_ratconst (Thm.dest_arg1 t)) (vector_lincomb (Thm.dest_arg t))
    56  | Const(@{const_name uminus}, _)$_ =>
    57      cterm_lincomb_neg (vector_lincomb (Thm.dest_arg t))
    58 (* FIXME: how should we handle numerals?
    59  | Const(@ {const_name vec},_)$_ =>
    60    let
    61      val b = ((snd o HOLogic.dest_number o term_of o Thm.dest_arg) t = 0
    62                handle TERM _=> false)
    63    in if b then FuncUtil.Ctermfunc.onefunc (t,Rat.one)
    64       else FuncUtil.Ctermfunc.empty
    65    end
    66 *)
    67  | _ => FuncUtil.Ctermfunc.onefunc (t,Rat.one)
    68 
    69  fun vector_lincombs ts =
    70   fold_rev
    71    (fn t => fn fns => case AList.lookup (op aconvc) fns t of
    72      NONE =>
    73        let val f = vector_lincomb t
    74        in case find_first (fn (_,f') => cterm_lincomb_eq f f') fns of
    75            SOME (_,f') => (t,f') :: fns
    76          | NONE => (t,f) :: fns
    77        end
    78    | SOME _ => fns) ts []
    79 
    80 fun replacenegnorms cv t = case term_of t of
    81   @{term "op + :: real => _"}$_$_ => binop_conv (replacenegnorms cv) t
    82 | @{term "op * :: real => _"}$_$_ =>
    83     if dest_ratconst (Thm.dest_arg1 t) </ Rat.zero then arg_conv cv t else Thm.reflexive t
    84 | _ => Thm.reflexive t
    85 fun flip v eq =
    86   if FuncUtil.Ctermfunc.defined eq v
    87   then FuncUtil.Ctermfunc.update (v, Rat.neg (FuncUtil.Ctermfunc.apply eq v)) eq else eq
    88 fun allsubsets s = case s of
    89   [] => [[]]
    90 |(a::t) => let val res = allsubsets t in
    91                map (cons a) res @ res end
    92 fun evaluate env lin =
    93  FuncUtil.Intfunc.fold (fn (x,c) => fn s => s +/ c */ (FuncUtil.Intfunc.apply env x))
    94    lin Rat.zero
    95 
    96 fun solve (vs,eqs) = case (vs,eqs) of
    97   ([],[]) => SOME (FuncUtil.Intfunc.onefunc (0,Rat.one))
    98  |(_,eq::oeqs) =>
    99    (case filter (member (op =) vs) (FuncUtil.Intfunc.dom eq) of (*FIXME use find_first here*)
   100      [] => NONE
   101     | v::_ =>
   102        if FuncUtil.Intfunc.defined eq v
   103        then
   104         let
   105          val c = FuncUtil.Intfunc.apply eq v
   106          val vdef = int_lincomb_cmul (Rat.neg (Rat.inv c)) eq
   107          fun eliminate eqn = if not (FuncUtil.Intfunc.defined eqn v) then eqn
   108                              else int_lincomb_add (int_lincomb_cmul (FuncUtil.Intfunc.apply eqn v) vdef) eqn
   109         in (case solve (remove (op =) v vs, map eliminate oeqs) of
   110             NONE => NONE
   111           | SOME soln => SOME (FuncUtil.Intfunc.update (v, evaluate soln (FuncUtil.Intfunc.delete_safe v vdef)) soln))
   112         end
   113        else NONE)
   114 
   115 fun combinations k l = if k = 0 then [[]] else
   116  case l of
   117   [] => []
   118 | h::t => map (cons h) (combinations (k - 1) t) @ combinations k t
   119 
   120 fun vertices vs eqs =
   121  let
   122   fun vertex cmb = case solve(vs,cmb) of
   123     NONE => NONE
   124    | SOME soln => SOME (map (fn v => FuncUtil.Intfunc.tryapplyd soln v Rat.zero) vs)
   125   val rawvs = map_filter vertex (combinations (length vs) eqs)
   126   val unset = filter (forall (fn c => c >=/ Rat.zero)) rawvs
   127  in fold_rev (insert (eq_list op =/)) unset []
   128  end
   129 
   130 val subsumes = eq_list (fn (x, y) => Rat.abs x <=/ Rat.abs y)
   131 
   132 fun subsume todo dun = case todo of
   133  [] => dun
   134 |v::ovs =>
   135    let val dun' = if exists (fn w => subsumes (w, v)) dun then dun
   136                   else v:: filter (fn w => not (subsumes (v, w))) dun
   137    in subsume ovs dun'
   138    end;
   139 
   140 fun match_mp PQ P = P RS PQ;
   141 
   142 fun cterm_of_rat x =
   143 let val (a, b) = Rat.quotient_of_rat x
   144 in
   145  if b = 1 then Numeral.mk_cnumber @{ctyp "real"} a
   146   else Thm.capply (Thm.capply @{cterm "op / :: real => _"}
   147                    (Numeral.mk_cnumber @{ctyp "real"} a))
   148         (Numeral.mk_cnumber @{ctyp "real"} b)
   149 end;
   150 
   151 fun norm_cmul_rule c th = instantiate' [] [SOME (cterm_of_rat c)] (th RS @{thm norm_cmul_rule_thm});
   152 
   153 fun norm_add_rule th1 th2 = [th1, th2] MRS @{thm norm_add_rule_thm};
   154 
   155   (* I think here the static context should be sufficient!! *)
   156 fun inequality_canon_rule ctxt =
   157  let
   158   (* FIXME : Should be computed statically!! *)
   159   val real_poly_conv =
   160     Semiring_Normalizer.semiring_normalize_wrapper ctxt
   161      (the (Semiring_Normalizer.match ctxt @{cterm "(0::real) + 1"}))
   162  in fconv_rule (arg_conv ((rewr_conv @{thm ge_iff_diff_ge_0}) then_conv arg_conv (Numeral_Simprocs.field_comp_conv then_conv real_poly_conv)))
   163 end;
   164 
   165  val apply_pth1 = rewr_conv @{thm pth_1};
   166  val apply_pth2 = rewr_conv @{thm pth_2};
   167  val apply_pth3 = rewr_conv @{thm pth_3};
   168  val apply_pth4 = rewrs_conv @{thms pth_4};
   169  val apply_pth5 = rewr_conv @{thm pth_5};
   170  val apply_pth6 = rewr_conv @{thm pth_6};
   171  val apply_pth7 = rewrs_conv @{thms pth_7};
   172  val apply_pth8 = rewr_conv @{thm pth_8} then_conv arg1_conv Numeral_Simprocs.field_comp_conv then_conv (try_conv (rewr_conv (mk_meta_eq @{thm scaleR_zero_left})));
   173  val apply_pth9 = rewrs_conv @{thms pth_9} then_conv arg1_conv (arg1_conv Numeral_Simprocs.field_comp_conv);
   174  val apply_ptha = rewr_conv @{thm pth_a};
   175  val apply_pthb = rewrs_conv @{thms pth_b};
   176  val apply_pthc = rewrs_conv @{thms pth_c};
   177  val apply_pthd = try_conv (rewr_conv @{thm pth_d});
   178 
   179 fun headvector t = case t of
   180   Const(@{const_name plus}, _)$
   181    (Const(@{const_name scaleR}, _)$l$v)$r => v
   182  | Const(@{const_name scaleR}, _)$l$v => v
   183  | _ => error "headvector: non-canonical term"
   184 
   185 fun vector_cmul_conv ct =
   186    ((apply_pth5 then_conv arg1_conv Numeral_Simprocs.field_comp_conv) else_conv
   187     (apply_pth6 then_conv binop_conv vector_cmul_conv)) ct
   188 
   189 fun vector_add_conv ct = apply_pth7 ct
   190  handle CTERM _ =>
   191   (apply_pth8 ct
   192    handle CTERM _ =>
   193     (case term_of ct of
   194      Const(@{const_name plus},_)$lt$rt =>
   195       let
   196        val l = headvector lt
   197        val r = headvector rt
   198       in (case Term_Ord.fast_term_ord (l,r) of
   199          LESS => (apply_pthb then_conv arg_conv vector_add_conv
   200                   then_conv apply_pthd) ct
   201         | GREATER => (apply_pthc then_conv arg_conv vector_add_conv
   202                      then_conv apply_pthd) ct
   203         | EQUAL => (apply_pth9 then_conv
   204                 ((apply_ptha then_conv vector_add_conv) else_conv
   205               arg_conv vector_add_conv then_conv apply_pthd)) ct)
   206       end
   207      | _ => Thm.reflexive ct))
   208 
   209 fun vector_canon_conv ct = case term_of ct of
   210  Const(@{const_name plus},_)$_$_ =>
   211   let
   212    val ((p,l),r) = Thm.dest_comb ct |>> Thm.dest_comb
   213    val lth = vector_canon_conv l
   214    val rth = vector_canon_conv r
   215    val th = Drule.binop_cong_rule p lth rth
   216   in fconv_rule (arg_conv vector_add_conv) th end
   217 
   218 | Const(@{const_name scaleR}, _)$_$_ =>
   219   let
   220    val (p,r) = Thm.dest_comb ct
   221    val rth = Drule.arg_cong_rule p (vector_canon_conv r)
   222   in fconv_rule (arg_conv (apply_pth4 else_conv vector_cmul_conv)) rth
   223   end
   224 
   225 | Const(@{const_name minus},_)$_$_ => (apply_pth2 then_conv vector_canon_conv) ct
   226 
   227 | Const(@{const_name uminus},_)$_ => (apply_pth3 then_conv vector_canon_conv) ct
   228 
   229 (* FIXME
   230 | Const(@{const_name vec},_)$n =>
   231   let val n = Thm.dest_arg ct
   232   in if is_ratconst n andalso not (dest_ratconst n =/ Rat.zero)
   233      then Thm.reflexive ct else apply_pth1 ct
   234   end
   235 *)
   236 | _ => apply_pth1 ct
   237 
   238 fun norm_canon_conv ct = case term_of ct of
   239   Const(@{const_name norm},_)$_ => arg_conv vector_canon_conv ct
   240  | _ => raise CTERM ("norm_canon_conv", [ct])
   241 
   242 fun int_flip v eq =
   243   if FuncUtil.Intfunc.defined eq v
   244   then FuncUtil.Intfunc.update (v, Rat.neg (FuncUtil.Intfunc.apply eq v)) eq else eq;
   245 
   246 local
   247  val pth_zero = @{thm norm_zero}
   248  val tv_n = (ctyp_of_term o Thm.dest_arg o Thm.dest_arg1 o Thm.dest_arg o cprop_of)
   249              pth_zero
   250  val concl = Thm.dest_arg o cprop_of
   251  fun real_vector_combo_prover ctxt translator (nubs,ges,gts) =
   252   let
   253    (* FIXME: Should be computed statically!!*)
   254    val real_poly_conv =
   255       Semiring_Normalizer.semiring_normalize_wrapper ctxt
   256        (the (Semiring_Normalizer.match ctxt @{cterm "(0::real) + 1"}))
   257    val sources = map (Thm.dest_arg o Thm.dest_arg1 o concl) nubs
   258    val rawdests = fold_rev (find_normedterms o Thm.dest_arg o concl) (ges @ gts) []
   259    val _ = if not (forall fst rawdests) then error "real_vector_combo_prover: Sanity check"
   260            else ()
   261    val dests = distinct (op aconvc) (map snd rawdests)
   262    val srcfuns = map vector_lincomb sources
   263    val destfuns = map vector_lincomb dests
   264    val vvs = fold_rev (union (op aconvc) o FuncUtil.Ctermfunc.dom) (srcfuns @ destfuns) []
   265    val n = length srcfuns
   266    val nvs = 1 upto n
   267    val srccombs = srcfuns ~~ nvs
   268    fun consider d =
   269     let
   270      fun coefficients x =
   271       let
   272        val inp = if FuncUtil.Ctermfunc.defined d x then FuncUtil.Intfunc.onefunc (0, Rat.neg(FuncUtil.Ctermfunc.apply d x))
   273                       else FuncUtil.Intfunc.empty
   274       in fold_rev (fn (f,v) => fn g => if FuncUtil.Ctermfunc.defined f x then FuncUtil.Intfunc.update (v, FuncUtil.Ctermfunc.apply f x) g else g) srccombs inp
   275       end
   276      val equations = map coefficients vvs
   277      val inequalities = map (fn n => FuncUtil.Intfunc.onefunc (n,Rat.one)) nvs
   278      fun plausiblevertices f =
   279       let
   280        val flippedequations = map (fold_rev int_flip f) equations
   281        val constraints = flippedequations @ inequalities
   282        val rawverts = vertices nvs constraints
   283        fun check_solution v =
   284         let
   285           val f = fold_rev2 (curry FuncUtil.Intfunc.update) nvs v (FuncUtil.Intfunc.onefunc (0, Rat.one))
   286         in forall (fn e => evaluate f e =/ Rat.zero) flippedequations
   287         end
   288        val goodverts = filter check_solution rawverts
   289        val signfixups = map (fn n => if member (op =) f n then ~1 else 1) nvs
   290       in map (map2 (fn s => fn c => Rat.rat_of_int s */ c) signfixups) goodverts
   291       end
   292      val allverts = fold_rev append (map plausiblevertices (allsubsets nvs)) []
   293     in subsume allverts []
   294     end
   295    fun compute_ineq v =
   296     let
   297      val ths = map_filter (fn (v,t) => if v =/ Rat.zero then NONE
   298                                      else SOME(norm_cmul_rule v t))
   299                             (v ~~ nubs)
   300      fun end_itlist f xs = split_last xs |> uncurry (fold_rev f)
   301     in inequality_canon_rule ctxt (end_itlist norm_add_rule ths)
   302     end
   303    val ges' = map_filter (try compute_ineq) (fold_rev (append o consider) destfuns []) @
   304                  map (inequality_canon_rule ctxt) nubs @ ges
   305    val zerodests = filter
   306         (fn t => null (FuncUtil.Ctermfunc.dom (vector_lincomb t))) (map snd rawdests)
   307 
   308   in fst (RealArith.real_linear_prover translator
   309         (map (fn t => instantiate ([(tv_n, ctyp_of_term t)],[]) pth_zero)
   310             zerodests,
   311         map (fconv_rule (try_conv (Conv.top_sweep_conv (K norm_canon_conv) ctxt) then_conv
   312                        arg_conv (arg_conv real_poly_conv))) ges',
   313         map (fconv_rule (try_conv (Conv.top_sweep_conv (K norm_canon_conv) ctxt) then_conv
   314                        arg_conv (arg_conv real_poly_conv))) gts))
   315   end
   316 in val real_vector_combo_prover = real_vector_combo_prover
   317 end;
   318 
   319 local
   320  val pth = @{thm norm_imp_pos_and_ge}
   321  val norm_mp = match_mp pth
   322  val concl = Thm.dest_arg o cprop_of
   323  fun conjunct1 th = th RS @{thm conjunct1}
   324  fun conjunct2 th = th RS @{thm conjunct2}
   325 fun real_vector_ineq_prover ctxt translator (ges,gts) =
   326  let
   327 (*   val _ = error "real_vector_ineq_prover: pause" *)
   328   val ntms = fold_rev find_normedterms (map (Thm.dest_arg o concl) (ges @ gts)) []
   329   val lctab = vector_lincombs (map snd (filter (not o fst) ntms))
   330   val (fxns, ctxt') = Variable.variant_fixes (replicate (length lctab) "x") ctxt
   331   fun instantiate_cterm' ty tms = Drule.cterm_rule (Drule.instantiate' ty tms)
   332   fun mk_norm t = Thm.capply (instantiate_cterm' [SOME (ctyp_of_term t)] [] @{cpat "norm :: (?'a :: real_normed_vector) => real"}) t
   333   fun mk_equals l r = Thm.capply (Thm.capply (instantiate_cterm' [SOME (ctyp_of_term l)] [] @{cpat "op == :: ?'a =>_"}) l) r
   334   val asl = map2 (fn (t,_) => fn n => Thm.assume (mk_equals (mk_norm t) (cterm_of (Proof_Context.theory_of ctxt') (Free(n,@{typ real}))))) lctab fxns
   335   val replace_conv = try_conv (rewrs_conv asl)
   336   val replace_rule = fconv_rule (funpow 2 arg_conv (replacenegnorms replace_conv))
   337   val ges' =
   338        fold_rev (fn th => fn ths => conjunct1(norm_mp th)::ths)
   339               asl (map replace_rule ges)
   340   val gts' = map replace_rule gts
   341   val nubs = map (conjunct2 o norm_mp) asl
   342   val th1 = real_vector_combo_prover ctxt' translator (nubs,ges',gts')
   343   val shs = filter (member (fn (t,th) => t aconvc cprop_of th) asl) (#hyps (crep_thm th1))
   344   val th11 = hd (Variable.export ctxt' ctxt [fold Thm.implies_intr shs th1])
   345   val cps = map (swap o Thm.dest_equals) (cprems_of th11)
   346   val th12 = instantiate ([], cps) th11
   347   val th13 = fold Thm.elim_implies (map (Thm.reflexive o snd) cps) th12;
   348  in hd (Variable.export ctxt' ctxt [th13])
   349  end
   350 in val real_vector_ineq_prover = real_vector_ineq_prover
   351 end;
   352 
   353 local
   354  val rawrule = fconv_rule (arg_conv (rewr_conv @{thm real_eq_0_iff_le_ge_0}))
   355  fun conj_pair th = (th RS @{thm conjunct1}, th RS @{thm conjunct2})
   356  fun simple_cterm_ord t u = Term_Ord.term_ord (term_of t, term_of u) = LESS;
   357   (* FIXME: Lookup in the context every time!!! Fix this !!!*)
   358  fun splitequation ctxt th acc =
   359   let
   360    val real_poly_neg_conv = #neg
   361        (Semiring_Normalizer.semiring_normalizers_ord_wrapper ctxt
   362         (the (Semiring_Normalizer.match ctxt @{cterm "(0::real) + 1"})) simple_cterm_ord)
   363    val (th1,th2) = conj_pair(rawrule th)
   364   in th1::fconv_rule (arg_conv (arg_conv real_poly_neg_conv)) th2::acc
   365   end
   366 in fun real_vector_prover ctxt _ translator (eqs,ges,gts) =
   367      (real_vector_ineq_prover ctxt translator
   368          (fold_rev (splitequation ctxt) eqs ges,gts), RealArith.Trivial)
   369 end;
   370 
   371   fun init_conv ctxt =
   372    Simplifier.rewrite (Simplifier.context ctxt
   373      (HOL_basic_ss addsimps ([(*@{thm vec_0}, @{thm vec_1},*) @{thm dist_norm}, @{thm diff_0_right}, @{thm right_minus}, @{thm diff_self}, @{thm norm_zero}] @ @{thms arithmetic_simps} @ @{thms norm_pths})))
   374    then_conv Numeral_Simprocs.field_comp_conv
   375    then_conv nnf_conv
   376 
   377  fun pure ctxt = fst o RealArith.gen_prover_real_arith ctxt (real_vector_prover ctxt);
   378  fun norm_arith ctxt ct =
   379   let
   380    val ctxt' = Variable.declare_term (term_of ct) ctxt
   381    val th = init_conv ctxt' ct
   382   in Thm.equal_elim (Drule.arg_cong_rule @{cterm Trueprop} (Thm.symmetric th))
   383                 (pure ctxt' (Thm.rhs_of th))
   384  end
   385 
   386  fun norm_arith_tac ctxt =
   387    clarify_tac (put_claset HOL_cs ctxt) THEN'
   388    Object_Logic.full_atomize_tac THEN'
   389    CSUBGOAL ( fn (p,i) => rtac (norm_arith ctxt (Thm.dest_arg p )) i);
   390 
   391 end;