src/Provers/Arith/fast_lin_arith.ML
changeset 66035 de6cd60b1226
parent 63227 d3ed7f00e818
child 67649 1e1782c1aedf
     1.1 --- a/src/Provers/Arith/fast_lin_arith.ML	Wed Jun 07 23:23:48 2017 +0200
     1.2 +++ b/src/Provers/Arith/fast_lin_arith.ML	Thu Jun 08 23:37:01 2017 +0200
     1.3 @@ -49,6 +49,12 @@
     1.4    val decomp: Proof.context -> term -> decomp option
     1.5    val domain_is_nat: term -> bool
     1.6  
     1.7 +  (*abstraction for proof replay*)
     1.8 +  val abstract_arith: term -> (term * term) list * Proof.context ->
     1.9 +    term * ((term * term) list * Proof.context)
    1.10 +  val abstract: term -> (term * term) list * Proof.context ->
    1.11 +    term * ((term * term) list * Proof.context)
    1.12 +
    1.13    (*preprocessing, performed on a representation of subgoals as list of premises:*)
    1.14    val pre_decomp: Proof.context -> typ list * term list -> (typ list * term list) list
    1.15  
    1.16 @@ -288,7 +294,7 @@
    1.17  fun extract_first p =
    1.18    let
    1.19      fun extract xs (y::ys) = if p y then (y, xs @ ys) else extract (y::xs) ys
    1.20 -      | extract xs [] = raise List.Empty
    1.21 +      | extract _ [] = raise List.Empty
    1.22    in extract [] end;
    1.23  
    1.24  fun print_ineqs ctxt ineqs =
    1.25 @@ -373,7 +379,7 @@
    1.26  with 0 <= n.
    1.27  *)
    1.28  local
    1.29 -  exception FalseE of thm
    1.30 +  exception FalseE of thm * (int * cterm) list * Proof.context
    1.31  in
    1.32  
    1.33  fun mkthm ctxt asms (just: injust) =
    1.34 @@ -439,29 +445,53 @@
    1.35                   | THM _ => mult_by_add n thm
    1.36            end);
    1.37  
    1.38 -    fun mult_thm (n, thm) =
    1.39 +    fun mult_thm n thm =
    1.40        if n = ~1 then thm RS LA_Logic.sym
    1.41        else if n < 0 then mult (~n) thm RS LA_Logic.sym
    1.42        else mult n thm;
    1.43  
    1.44 -    fun simp thm =
    1.45 +    fun simp thm (cx as (_, hyps, ctxt')) =
    1.46        let val thm' = trace_thm ctxt ["Simplified:"] (full_simplify simpset_ctxt thm)
    1.47 -      in if LA_Logic.is_False thm' then raise FalseE thm' else thm' end;
    1.48 +      in if LA_Logic.is_False thm' then raise FalseE (thm', hyps, ctxt') else (thm', cx) end;
    1.49 +
    1.50 +    fun abs_thm i (cx as (terms, hyps, ctxt)) =
    1.51 +      (case AList.lookup (op =) hyps i of
    1.52 +        SOME ct => (Thm.assume ct, cx)
    1.53 +      | NONE =>
    1.54 +          let
    1.55 +            val thm = nth asms i
    1.56 +            val (t, (terms', ctxt')) = LA_Data.abstract (Thm.prop_of thm) (terms, ctxt)
    1.57 +            val ct = Thm.cterm_of ctxt' t
    1.58 +          in (Thm.assume ct, (terms', (i, ct) :: hyps, ctxt')) end);
    1.59 +
    1.60 +    fun nat_thm t (terms, hyps, ctxt) =
    1.61 +      let val (t', (terms', ctxt')) = LA_Data.abstract_arith t (terms, ctxt)
    1.62 +      in (LA_Logic.mk_nat_thm thy t', (terms', hyps, ctxt')) end;
    1.63  
    1.64 -    fun mk (Asm i) = trace_thm ctxt ["Asm " ^ string_of_int i] (nth asms i)
    1.65 -      | mk (Nat i) = trace_thm ctxt ["Nat " ^ string_of_int i] (LA_Logic.mk_nat_thm thy (nth atoms i))
    1.66 -      | mk (LessD j) = trace_thm ctxt ["L"] (hd ([mk j] RL lessD))
    1.67 -      | mk (NotLeD j) = trace_thm ctxt ["NLe"] (mk j RS LA_Logic.not_leD)
    1.68 -      | mk (NotLeDD j) = trace_thm ctxt ["NLeD"] (hd ([mk j RS LA_Logic.not_leD] RL lessD))
    1.69 -      | mk (NotLessD j) = trace_thm ctxt ["NL"] (mk j RS LA_Logic.not_lessD)
    1.70 -      | mk (Added (j1, j2)) = simp (trace_thm ctxt ["+"] (add_thms (mk j1) (mk j2)))
    1.71 -      | mk (Multiplied (n, j)) =
    1.72 -          (trace_msg ctxt ("*" ^ string_of_int n); trace_thm ctxt ["*"] (mult_thm (n, mk j)))
    1.73 +    fun step0 msg (thm, cx) = (trace_thm ctxt [msg] thm, cx)
    1.74 +    fun step1 msg j f cx = mk j cx |>> f |>> trace_thm ctxt [msg]
    1.75 +    and step2 msg j1 j2 f cx = mk j1 cx ||>> mk j2 |>> f |>> trace_thm ctxt [msg]
    1.76  
    1.77 +    and mk (Asm i) cx = step0 ("Asm " ^ string_of_int i) (abs_thm i cx)
    1.78 +      | mk (Nat i) cx = step0 ("Nat " ^ string_of_int i) (nat_thm (nth atoms i) cx)
    1.79 +      | mk (LessD j) cx = step1 "L" j (fn thm => hd ([thm] RL lessD)) cx
    1.80 +      | mk (NotLeD j) cx = step1 "NLe" j (fn thm => thm RS LA_Logic.not_leD) cx
    1.81 +      | mk (NotLeDD j) cx = step1 "NLeD" j (fn thm => hd ([thm RS LA_Logic.not_leD] RL lessD)) cx
    1.82 +      | mk (NotLessD j) cx = step1 "NL" j (fn thm => thm RS LA_Logic.not_lessD) cx
    1.83 +      | mk (Added (j1, j2)) cx = step2 "+" j1 j2 (uncurry add_thms) cx |-> simp
    1.84 +      | mk (Multiplied (n, j)) cx =
    1.85 +          (trace_msg ctxt ("*" ^ string_of_int n); step1 "*" j (mult_thm n) cx)
    1.86 +
    1.87 +    fun finish ctxt' hyps thm =
    1.88 +      thm
    1.89 +      |> fold_rev (Thm.implies_intr o snd) hyps
    1.90 +      |> singleton (Variable.export ctxt' ctxt)
    1.91 +      |> fold (fn (i, _) => fn thm => nth asms i RS thm) hyps
    1.92    in
    1.93      let
    1.94        val _ = trace_msg ctxt "mkthm";
    1.95 -      val thm = trace_thm ctxt ["Final thm:"] (mk just);
    1.96 +      val (thm, (_, hyps, ctxt')) = mk just ([], [], ctxt);
    1.97 +      val _ = trace_thm ctxt ["Final thm:"] thm;
    1.98        val fls = simplify simpset_ctxt thm;
    1.99        val _ = trace_thm ctxt ["After simplification:"] fls;
   1.100        val _ =
   1.101 @@ -472,8 +502,9 @@
   1.102              ["Proved:", Thm.string_of_thm ctxt fls, ""]));
   1.103            warning "Linear arithmetic should have refuted the assumptions.\n\
   1.104              \Please inform Tobias Nipkow.")
   1.105 -    in fls end
   1.106 -    handle FalseE thm => trace_thm ctxt ["False reached early:"] thm
   1.107 +    in finish ctxt' hyps fls end
   1.108 +    handle FalseE (thm, hyps, ctxt') =>
   1.109 +      trace_thm ctxt ["False reached early:"] (finish ctxt' hyps thm)
   1.110    end;
   1.111  
   1.112  end;
   1.113 @@ -555,7 +586,7 @@
   1.114    fun elim_neq (ineqs : (LA_Data.decomp option * bool) list) :
   1.115                 (LA_Data.decomp option * bool) list list =
   1.116    let
   1.117 -    fun elim_neq' nat_only ([] : (LA_Data.decomp option * bool) list) :
   1.118 +    fun elim_neq' _ ([] : (LA_Data.decomp option * bool) list) :
   1.119                    (LA_Data.decomp option * bool) list list =
   1.120            [[]]
   1.121        | elim_neq' nat_only ((NONE, is_nat) :: ineqs) =