src/Provers/order_tac.ML
changeset 73526 a3cc9fa1295d
child 74561 8e6c973003c8
equal deleted inserted replaced
73517:d3f2038198ae 73526:a3cc9fa1295d
       
     1 signature REIFY_TABLE =
       
     2 sig
       
     3   type table
       
     4   val empty : table
       
     5   val get_var : term -> table -> (int * table)
       
     6   val get_term : int -> table -> term option
       
     7 end
       
     8 
       
     9 structure Reifytab: REIFY_TABLE =
       
    10 struct
       
    11   type table = (int * int Termtab.table * term Inttab.table)
       
    12   
       
    13   val empty = (0, Termtab.empty, Inttab.empty)
       
    14   
       
    15   fun get_var t (tab as (max_var, termtab, inttab)) =
       
    16     (case Termtab.lookup termtab t of
       
    17       SOME v => (v, tab)
       
    18     | NONE => (max_var,
       
    19               (max_var + 1, Termtab.update (t, max_var) termtab, Inttab.update (max_var, t) inttab))
       
    20     )
       
    21   
       
    22   fun get_term v (_, _, inttab) = Inttab.lookup inttab v
       
    23 end
       
    24 
       
    25 signature LOGIC_SIGNATURE =
       
    26 sig
       
    27   val mk_Trueprop : term -> term
       
    28   val dest_Trueprop : term -> term
       
    29   val Trueprop_conv : conv -> conv
       
    30   val Not : term
       
    31   val conj : term
       
    32   val disj : term
       
    33   
       
    34   val notI : thm (* (P \<Longrightarrow> False) \<Longrightarrow> \<not> P *)
       
    35   val ccontr : thm (* (\<not> P \<Longrightarrow> False) \<Longrightarrow> P *)
       
    36   val conjI : thm (* P \<Longrightarrow> Q \<Longrightarrow> P \<and> Q *)
       
    37   val conjE : thm (* P \<and> Q \<Longrightarrow> (P \<Longrightarrow> Q \<Longrightarrow> R) \<Longrightarrow> R *)
       
    38   val disjE : thm (* P \<or> Q \<Longrightarrow> (P \<Longrightarrow> R) \<Longrightarrow> (Q \<Longrightarrow> R) \<Longrightarrow> R *)
       
    39 
       
    40   val not_not_conv : conv (* \<not> (\<not> P) \<equiv> P *)
       
    41   val de_Morgan_conj_conv : conv (* \<not> (P \<and> Q) \<equiv> \<not> P \<or> \<not> Q *)
       
    42   val de_Morgan_disj_conv : conv (* \<not> (P \<or> Q) \<equiv> \<not> P \<and> \<not> Q *)
       
    43   val conj_disj_distribL_conv : conv (* P \<and> (Q \<or> R) \<equiv> (P \<and> Q) \<or> (P \<and> R) *)
       
    44   val conj_disj_distribR_conv : conv (* (Q \<or> R) \<and> P \<equiv> (Q \<and> P) \<or> (R \<and> P) *)
       
    45 end
       
    46 
       
    47 (* Control tracing output of the solver. *)
       
    48 val order_trace_cfg = Attrib.setup_config_bool @{binding "order_trace"} (K false)
       
    49 (* In partial orders, literals of the form \<not> x < y will force the order solver to perform case
       
    50    distinctions, which leads to an exponential blowup of the runtime. The split limit controls
       
    51    the number of literals of this form that are passed to the solver. 
       
    52  *)
       
    53 val order_split_limit_cfg = Attrib.setup_config_int @{binding "order_split_limit"} (K 8)
       
    54 
       
    55 datatype order_kind = Order | Linorder
       
    56 
       
    57 type order_literal = (bool * Order_Procedure.o_atom)
       
    58 
       
    59 type order_context = {
       
    60     kind : order_kind,
       
    61     ops : term list, thms : (string * thm) list, conv_thms : (string * thm) list
       
    62   }
       
    63 
       
    64 signature BASE_ORDER_TAC =
       
    65 sig
       
    66 
       
    67   val tac :
       
    68         (order_literal Order_Procedure.fm -> Order_Procedure.prf_trm option)
       
    69         -> order_context -> thm list
       
    70         -> Proof.context -> int -> tactic
       
    71 end
       
    72 
       
    73 functor Base_Order_Tac(
       
    74   structure Logic_Sig : LOGIC_SIGNATURE; val excluded_types : typ list) : BASE_ORDER_TAC =
       
    75 struct
       
    76   open Order_Procedure
       
    77 
       
    78   fun expect _ (SOME x) = x
       
    79     | expect f NONE = f ()
       
    80 
       
    81   fun matches_skeleton t s = t = Term.dummy orelse
       
    82     (case (t, s) of
       
    83       (t0 $ t1, s0 $ s1) => matches_skeleton t0 s0 andalso matches_skeleton t1 s1
       
    84     | _ => t aconv s)
       
    85 
       
    86   fun dest_binop t =
       
    87     let
       
    88       val binop_skel = Term.dummy $ Term.dummy $ Term.dummy
       
    89       val not_binop_skel = Logic_Sig.Not $ binop_skel
       
    90     in
       
    91       if matches_skeleton not_binop_skel t
       
    92         then (case t of (_ $ (t1 $ t2 $ t3)) => (false, (t1, t2, t3)))
       
    93         else if matches_skeleton binop_skel t
       
    94           then (case t of (t1 $ t2 $ t3) => (true, (t1, t2, t3)))
       
    95           else raise TERM ("Not a binop literal", [t])
       
    96     end
       
    97 
       
    98   fun find_term t = Library.find_first (fn (t', _) => t' aconv t)
       
    99 
       
   100   fun reify_order_atom (eq, le, lt) t reifytab =
       
   101     let
       
   102       val (b, (t0, t1, t2)) =
       
   103         (dest_binop t) handle TERM (_, _) => raise TERM ("Can't reify order literal", [t])
       
   104       val binops = [(eq, EQ), (le, LEQ), (lt, LESS)]
       
   105     in
       
   106       case find_term t0 binops of
       
   107         SOME (_, reified_bop) =>
       
   108           reifytab
       
   109           |> Reifytab.get_var t1 ||> Reifytab.get_var t2
       
   110           |> (fn (v1, (v2, vartab')) =>
       
   111                ((b, reified_bop (Int_of_integer v1, Int_of_integer v2)), vartab'))
       
   112           |>> Atom
       
   113       | NONE => raise TERM ("Can't reify order literal", [t])
       
   114     end
       
   115 
       
   116   fun reify consts reify_atom t reifytab =
       
   117     let
       
   118       fun reify' (t1 $ t2) reifytab =
       
   119             let
       
   120               val (t0, ts) = strip_comb (t1 $ t2)
       
   121               val consts_of_arity = filter (fn (_, (_, ar)) => length ts = ar) consts
       
   122             in
       
   123               (case find_term t0 consts_of_arity of
       
   124                 SOME (_, (reified_op, _)) => fold_map reify' ts reifytab |>> reified_op
       
   125               | NONE => reify_atom (t1 $ t2) reifytab)
       
   126             end
       
   127         | reify' t reifytab = reify_atom t reifytab
       
   128     in
       
   129       reify' t reifytab
       
   130     end
       
   131 
       
   132   fun list_curry0 f = (fn [] => f, 0)
       
   133   fun list_curry1 f = (fn [x] => f x, 1)
       
   134   fun list_curry2 f = (fn [x, y] => f x y, 2)
       
   135 
       
   136   fun reify_order_conj ord_ops =
       
   137     let
       
   138       val consts = map (apsnd (list_curry2 o curry)) [(Logic_Sig.conj, And), (Logic_Sig.disj, Or)]
       
   139     in   
       
   140       reify consts (reify_order_atom ord_ops)
       
   141     end
       
   142 
       
   143   fun dereify_term consts reifytab t =
       
   144     let
       
   145       fun dereify_term' (App (t1, t2)) = (dereify_term' t1) $ (dereify_term' t2)
       
   146         | dereify_term' (Const s) =
       
   147             AList.lookup (op =) consts s
       
   148             |> expect (fn () => raise TERM ("Const " ^ s ^ " not in", map snd consts))
       
   149         | dereify_term' (Var v) = Reifytab.get_term (integer_of_int v) reifytab |> the
       
   150     in
       
   151       dereify_term' t
       
   152     end
       
   153 
       
   154   fun dereify_order_fm (eq, le, lt) reifytab t =
       
   155     let
       
   156       val consts = [
       
   157         ("eq", eq), ("le", le), ("lt", lt),
       
   158         ("Not", Logic_Sig.Not), ("disj", Logic_Sig.disj), ("conj", Logic_Sig.conj)
       
   159         ]
       
   160     in
       
   161       dereify_term consts reifytab t
       
   162     end
       
   163 
       
   164   fun strip_AppP t =
       
   165     let fun strip (AppP (f, s), ss) = strip (f, s::ss)
       
   166           | strip x = x
       
   167     in strip (t, []) end
       
   168 
       
   169   fun replay_conv convs cvp =
       
   170     let
       
   171       val convs = convs @
       
   172         [("all_conv", list_curry0 Conv.all_conv)] @ 
       
   173         map (apsnd list_curry1) [
       
   174           ("atom_conv", I),
       
   175           ("neg_atom_conv", I),
       
   176           ("arg_conv", Conv.arg_conv)] @
       
   177         map (apsnd list_curry2) [
       
   178           ("combination_conv", Conv.combination_conv),
       
   179           ("then_conv", curry (op then_conv))]
       
   180 
       
   181       fun lookup_conv convs c = AList.lookup (op =) convs c
       
   182             |> expect (fn () => error ("Can't replay conversion: " ^ c))
       
   183 
       
   184       fun rp_conv t =
       
   185         (case strip_AppP t ||> map rp_conv of
       
   186           (PThm c, cvs) =>
       
   187             let val (conv, arity) = lookup_conv convs c
       
   188             in if arity = length cvs
       
   189               then conv cvs
       
   190               else error ("Expected " ^ Int.toString arity ^ " arguments for conversion " ^
       
   191                           c ^ " but got " ^ (length cvs |> Int.toString) ^ " arguments")
       
   192             end
       
   193         | _ => error "Unexpected constructor in conversion proof")
       
   194     in
       
   195       rp_conv cvp
       
   196     end
       
   197 
       
   198   fun replay_prf_trm replay_conv dereify ctxt thmtab assmtab p =
       
   199     let
       
   200       fun replay_prf_trm' _ (PThm s) =
       
   201             AList.lookup (op =) thmtab s
       
   202             |> expect (fn () => error ("Cannot replay theorem: " ^ s))
       
   203         | replay_prf_trm' assmtab (Appt (p, t)) =
       
   204             replay_prf_trm' assmtab p
       
   205             |> Drule.infer_instantiate' ctxt [SOME (Thm.cterm_of ctxt (dereify t))]
       
   206         | replay_prf_trm' assmtab (AppP (p1, p2)) =
       
   207             apply2 (replay_prf_trm' assmtab) (p2, p1) |> (op COMP)
       
   208         | replay_prf_trm' assmtab (AbsP (reified_t, p)) =
       
   209             let
       
   210               val t = dereify reified_t
       
   211               val t_thm = Logic_Sig.mk_Trueprop t |> Thm.cterm_of ctxt |> Assumption.assume ctxt
       
   212               val rp = replay_prf_trm' (Termtab.update (Thm.prop_of t_thm, t_thm) assmtab) p
       
   213             in
       
   214               Thm.implies_intr (Thm.cprop_of t_thm) rp
       
   215             end
       
   216         | replay_prf_trm' assmtab (Bound reified_t) =
       
   217             let
       
   218               val t = dereify reified_t |> Logic_Sig.mk_Trueprop
       
   219             in
       
   220               Termtab.lookup assmtab t
       
   221               |> expect (fn () => raise TERM ("Assumption not found:", t::Termtab.keys assmtab))
       
   222             end
       
   223         | replay_prf_trm' assmtab (Conv (t, cp, p)) =
       
   224             let
       
   225               val thm = replay_prf_trm' assmtab (Bound t)
       
   226               val conv = Logic_Sig.Trueprop_conv (replay_conv cp)
       
   227               val conv_thm = Conv.fconv_rule conv thm
       
   228               val conv_term = Thm.prop_of conv_thm
       
   229             in
       
   230               replay_prf_trm' (Termtab.update (conv_term, conv_thm) assmtab) p
       
   231             end
       
   232     in
       
   233       replay_prf_trm' assmtab p
       
   234     end
       
   235 
       
   236   fun replay_order_prf_trm ord_ops {thms = thms, conv_thms = conv_thms, ...} ctxt reifytab assmtab =
       
   237     let
       
   238       val thmtab = thms @ [
       
   239           ("conjE", Logic_Sig.conjE), ("conjI", Logic_Sig.conjI), ("disjE", Logic_Sig.disjE)
       
   240         ]
       
   241       val convs = map (apsnd list_curry0) (
       
   242         map (apsnd Conv.rewr_conv) conv_thms @
       
   243         [
       
   244           ("not_not_conv", Logic_Sig.not_not_conv),
       
   245           ("de_Morgan_conj_conv", Logic_Sig.de_Morgan_conj_conv),
       
   246           ("de_Morgan_disj_conv", Logic_Sig.de_Morgan_disj_conv),
       
   247           ("conj_disj_distribR_conv", Logic_Sig.conj_disj_distribR_conv),
       
   248           ("conj_disj_distribL_conv", Logic_Sig.conj_disj_distribL_conv)
       
   249         ])
       
   250       
       
   251       val dereify = dereify_order_fm ord_ops reifytab
       
   252     in
       
   253       replay_prf_trm (replay_conv convs) dereify ctxt thmtab assmtab
       
   254     end
       
   255 
       
   256   fun is_binop_term t =
       
   257     let
       
   258       fun is_included t = forall (curry (op <>) (t |> fastype_of |> domain_type)) excluded_types
       
   259     in
       
   260       (case dest_binop (Logic_Sig.dest_Trueprop t) of
       
   261         (_, (binop, t1, t2)) =>
       
   262           is_included binop andalso
       
   263           (* Exclude terms with schematic variables since the solver can't deal with them.
       
   264              More specifically, the solver uses Assumption.assume which does not allow schematic
       
   265              variables in the assumed cterm.
       
   266           *)
       
   267           Term.add_var_names (binop $ t1 $ t2) [] = []
       
   268       ) handle TERM (_, _) => false
       
   269     end
       
   270 
       
   271   fun partition_matches ctxt term_of pats ys =
       
   272     let
       
   273       val thy = Proof_Context.theory_of ctxt
       
   274 
       
   275       fun find_match t env =
       
   276         Library.get_first (try (fn pat => Pattern.match thy (pat, t) env)) pats
       
   277       
       
   278       fun filter_matches xs = fold (fn x => fn (mxs, nmxs, env) =>
       
   279         case find_match (term_of x) env of
       
   280           SOME env' => (x::mxs, nmxs, env')
       
   281         | NONE => (mxs, x::nmxs, env)) xs ([], [], (Vartab.empty, Vartab.empty))
       
   282 
       
   283       fun partition xs =
       
   284         case filter_matches xs of
       
   285           ([], _, _) => []
       
   286         | (mxs, nmxs, env) => (env, mxs) :: partition nmxs
       
   287     in
       
   288       partition ys
       
   289     end
       
   290 
       
   291   fun limit_not_less [_, _, lt] ctxt prems =
       
   292     let
       
   293       val thy = Proof_Context.theory_of ctxt
       
   294       val trace = Config.get ctxt order_trace_cfg
       
   295       val limit = Config.get ctxt order_split_limit_cfg
       
   296 
       
   297       fun is_not_less_term t =
       
   298         (case dest_binop (Logic_Sig.dest_Trueprop t) of
       
   299           (false, (t0, _, _)) => Pattern.matches thy (lt, t0)
       
   300         | _ => false)
       
   301         handle TERM _ => false
       
   302 
       
   303       val not_less_prems = filter (is_not_less_term o Thm.prop_of) prems
       
   304       val _ = if trace andalso length not_less_prems > limit
       
   305                 then tracing "order split limit exceeded"
       
   306                 else ()
       
   307      in
       
   308       filter_out (is_not_less_term o Thm.prop_of) prems @
       
   309       take limit not_less_prems
       
   310      end
       
   311       
       
   312   fun order_tac raw_order_proc octxt simp_prems =
       
   313     Subgoal.FOCUS (fn {prems=prems, context=ctxt, ...} =>
       
   314       let
       
   315         val trace = Config.get ctxt order_trace_cfg
       
   316 
       
   317         val binop_prems = filter (is_binop_term o Thm.prop_of) (prems @ simp_prems)
       
   318         val strip_binop = (fn (x, _, _) => x) o snd o dest_binop
       
   319         val binop_of = strip_binop o Logic_Sig.dest_Trueprop o Thm.prop_of
       
   320 
       
   321         (* Due to local_setup, the operators of the order may contain schematic term and type
       
   322            variables. We partition the premises according to distinct instances of those operators.
       
   323          *)
       
   324         val part_prems = partition_matches ctxt binop_of (#ops octxt) binop_prems
       
   325           |> (case #kind octxt of
       
   326                 Order => map (fn (env, prems) =>
       
   327                           (env, limit_not_less (#ops octxt) ctxt prems))
       
   328               | _ => I)
       
   329               
       
   330         fun order_tac' (_, []) = no_tac
       
   331           | order_tac' (env, prems) =
       
   332             let
       
   333               val [eq, le, lt] = #ops octxt
       
   334               val subst_contract = Envir.eta_contract o Envir.subst_term env
       
   335               val ord_ops = (subst_contract eq,
       
   336                              subst_contract le,
       
   337                              subst_contract lt)
       
   338   
       
   339               val _ = if trace then @{print} (ord_ops, prems) else (ord_ops, prems)
       
   340   
       
   341               val prems_conj_thm = foldl1 (fn (x, a) => Logic_Sig.conjI OF [x, a]) prems
       
   342                 |> Conv.fconv_rule Thm.eta_conversion 
       
   343               val prems_conj = prems_conj_thm |> Thm.prop_of
       
   344               val (reified_prems_conj, reifytab) =
       
   345                 reify_order_conj ord_ops (Logic_Sig.dest_Trueprop prems_conj) Reifytab.empty
       
   346   
       
   347               val proof = raw_order_proc reified_prems_conj
       
   348   
       
   349               val assmtab = Termtab.make [(prems_conj, prems_conj_thm)]
       
   350               val replay = replay_order_prf_trm ord_ops octxt ctxt reifytab assmtab
       
   351             in
       
   352               case proof of
       
   353                 NONE => no_tac
       
   354               | SOME p => SOLVED' (resolve_tac ctxt [replay p]) 1
       
   355             end
       
   356      in
       
   357       FIRST (map order_tac' part_prems)
       
   358      end)
       
   359 
       
   360   val ad_absurdum_tac = SUBGOAL (fn (A, i) =>
       
   361       case try (Logic_Sig.dest_Trueprop o Logic.strip_assums_concl) A of
       
   362         SOME (nt $ _) =>
       
   363           if nt = Logic_Sig.Not
       
   364             then resolve0_tac [Logic_Sig.notI] i
       
   365             else resolve0_tac [Logic_Sig.ccontr] i
       
   366       | SOME _ => resolve0_tac [Logic_Sig.ccontr] i
       
   367       | NONE => resolve0_tac [Logic_Sig.ccontr] i)
       
   368 
       
   369   fun tac raw_order_proc octxt simp_prems ctxt =
       
   370       EVERY' [
       
   371           ad_absurdum_tac,
       
   372           CONVERSION Thm.eta_conversion,
       
   373           order_tac raw_order_proc octxt simp_prems ctxt
       
   374         ]
       
   375   
       
   376 end
       
   377 
       
   378 functor Order_Tac(structure Base_Tac : BASE_ORDER_TAC) = struct
       
   379 
       
   380   fun order_context_eq ({kind = kind1, ops = ops1, ...}, {kind = kind2, ops = ops2, ...}) =
       
   381     kind1 = kind2 andalso eq_list (op aconv) (ops1, ops2)
       
   382 
       
   383   fun order_data_eq (x, y) = order_context_eq (fst x, fst y)
       
   384   
       
   385   structure Data = Generic_Data(
       
   386     type T = (order_context * (order_context -> thm list -> Proof.context -> int -> tactic)) list
       
   387     val empty = []
       
   388     val extend = I
       
   389     fun merge data = Library.merge order_data_eq data
       
   390   )
       
   391 
       
   392   fun declare (octxt as {kind = kind, raw_proc = raw_proc, ...}) lthy =
       
   393     lthy |> Local_Theory.declaration {syntax = false, pervasive = false} (fn phi => fn context =>
       
   394       let
       
   395         val ops = map (Morphism.term phi) (#ops octxt)
       
   396         val thms = map (fn (s, thm) => (s, Morphism.thm phi thm)) (#thms octxt)
       
   397         val conv_thms = map (fn (s, thm) => (s, Morphism.thm phi thm)) (#conv_thms octxt)
       
   398         val octxt' = {kind = kind, ops = ops, thms = thms, conv_thms = conv_thms}
       
   399       in
       
   400         context |> Data.map (Library.insert order_data_eq (octxt', raw_proc))
       
   401       end)
       
   402 
       
   403   fun declare_order {
       
   404       ops = {eq = eq, le = le, lt = lt},
       
   405       thms = {
       
   406         trans = trans, (* x \<le> y \<Longrightarrow> y \<le> z \<Longrightarrow> x \<le> z *)
       
   407         refl = refl, (* x \<le> x *)
       
   408         eqD1 = eqD1, (* x = y \<Longrightarrow> x \<le> y *)
       
   409         eqD2 = eqD2, (* x = y \<Longrightarrow> y \<le> x *)
       
   410         antisym = antisym, (* x \<le> y \<Longrightarrow> y \<le> x \<Longrightarrow> x = y *)
       
   411         contr = contr (* \<not> P \<Longrightarrow> P \<Longrightarrow> R *)
       
   412       },
       
   413       conv_thms = {
       
   414         less_le = less_le, (* x < y \<equiv> x \<le> y \<and> x \<noteq> y *)
       
   415         nless_le = nless_le (* \<not> a < b \<equiv> \<not> a \<le> b \<or> a = b *)
       
   416       }
       
   417     } =
       
   418     declare {
       
   419       kind = Order,
       
   420       ops = [eq, le, lt],
       
   421       thms = [("trans", trans), ("refl", refl), ("eqD1", eqD1), ("eqD2", eqD2),
       
   422               ("antisym", antisym), ("contr", contr)],
       
   423       conv_thms = [("less_le", less_le), ("nless_le", nless_le)],
       
   424       raw_proc = Base_Tac.tac Order_Procedure.po_contr_prf
       
   425      }                
       
   426 
       
   427   fun declare_linorder {
       
   428       ops = {eq = eq, le = le, lt = lt},
       
   429       thms = {
       
   430         trans = trans, (* x \<le> y \<Longrightarrow> y \<le> z \<Longrightarrow> x \<le> z *)
       
   431         refl = refl, (* x \<le> x *)
       
   432         eqD1 = eqD1, (* x = y \<Longrightarrow> x \<le> y *)
       
   433         eqD2 = eqD2, (* x = y \<Longrightarrow> y \<le> x *)
       
   434         antisym = antisym, (* x \<le> y \<Longrightarrow> y \<le> x \<Longrightarrow> x = y *)
       
   435         contr = contr (* \<not> P \<Longrightarrow> P \<Longrightarrow> R *)
       
   436       },
       
   437       conv_thms = {
       
   438         less_le = less_le, (* x < y \<equiv> x \<le> y \<and> x \<noteq> y *)
       
   439         nless_le = nless_le, (* \<not> x < y \<equiv> y \<le> x *)
       
   440         nle_le = nle_le (* \<not> a \<le> b \<equiv> b \<le> a \<and> b \<noteq> a *)
       
   441       }
       
   442     } =
       
   443     declare {
       
   444       kind = Linorder,
       
   445       ops = [eq, le, lt],
       
   446       thms = [("trans", trans), ("refl", refl), ("eqD1", eqD1), ("eqD2", eqD2),
       
   447               ("antisym", antisym), ("contr", contr)],
       
   448       conv_thms = [("less_le", less_le), ("nless_le", nless_le), ("nle_le", nle_le)],
       
   449       raw_proc = Base_Tac.tac Order_Procedure.lo_contr_prf
       
   450      }
       
   451   
       
   452   (* Try to solve the goal by calling the order solver with each of the declared orders. *)      
       
   453   fun tac simp_prems ctxt =
       
   454     let fun app_tac (octxt, tac0) = CHANGED o tac0 octxt simp_prems ctxt
       
   455     in FIRST' (map app_tac (Data.get (Context.Proof ctxt))) end
       
   456 end