src/Pure/Isar/rule_insts.ML
author wenzelm
Thu, 23 Nov 2006 20:33:41 +0100
changeset 21500 146938537ddc
parent 20548 8ef25fe585a8
child 21879 a3efbae45735
permissions -rw-r--r--
renamed Args.Name to Args.Text;

(*  Title:      Pure/Isar/rule_insts.ML
    ID:         $Id$
    Author:     Makarius

Rule instantiations -- operations within a rule/subgoal context.
*)

signature RULE_INSTS =
sig
  (*bires_inst_tac flag ctxt insts rule i: instantiate rule by insts (read
    against the dynamic goal state) and compose it with subgoal i; flag is
    passed on to compose_tac (false for plain resolution, true for
    elim-resolution)*)
  val bires_inst_tac:
    bool -> Proof.context -> (indexname * string) list -> thm -> int -> tactic
end;

structure RuleInsts: RULE_INSTS =
struct


(** reading instantiations **)

local

(*indexnames whose name starts with a prime denote type variables*)
fun is_tvar (name, _) = size name > 0 andalso String.sub (name, 0) = #"'";

(*raise an error whose message is msg followed by the printed indexname xi*)
fun error_var msg xi = error (String.concat [msg, Syntax.string_of_vname xi]);

(*sort of type variable xi according to tvars; error if xi is not listed*)
fun the_sort tvars xi =
  (case AList.lookup (op =) tvars xi of
    SOME S => S
  | NONE => error_var "No such type variable in theorem: " xi);

(*type of term variable xi according to vars; error if xi is not listed*)
fun the_type vars xi =
  (case AList.lookup (op =) vars xi of
    SOME T => T
  | NONE => error_var "No such variable in theorem: " xi);

(*extend (unifier, maxidx) by unifying the declared type of variable xi with
  the type of its instance u; maxidx is raised to cover xi and u first*)
fun unify_vartypes thy vars (xi as (_, idx), u) (unifier, maxidx) =
  let
    val varT = the_type vars xi;
    val instT = Term.fastype_of u;
    val bound = Term.maxidx_term u (Int.max (idx, maxidx));
  in
    Sign.typ_unify thy (varT, instT) (unifier, bound)
      handle Type.TUNIFY => error_var "Incompatible type for instantiation of " xi
  end;

(*term substitution replacing each Var (xi, type of t) by t, followed by
  beta normalization; the substitution is built once per partial application*)
fun instantiate inst =
  let val subst = map (fn (xi, t) => ((xi, Term.fastype_of t), t)) inst
  in Envir.beta_norm o TermSubst.instantiate ([], subst) end;

(*SOME (TVar v, f (TVar v)) if f changes the type variable, NONE otherwise*)
fun make_instT f v =
  let
    val old = TVar v;
    val new = f old;
  in
    if old <> new then SOME (old, new) else NONE
  end;

(*SOME (Var v, f (Var v)) if f changes the variable (up to alpha-conversion),
  NONE otherwise*)
fun make_inst f v =
  let
    val old = Var v;
    val new = f old;
  in
    if not (old aconv new) then SOME (old, new) else NONE
  end;

in

(*read_insts ctxt mixed_insts (tvars, vars): internalize the instantiations
  mixed_insts -- pairs of indexname and Args value -- for a theorem whose
  schematic type variables (with sorts) are tvars and whose schematic term
  variables (with types) are vars.  Indexnames starting with a prime are
  treated as type instantiations (is_tvar).

  Result: ((type_insts, term_insts), (ctyp pairs, cterm pairs)); the first
  component maps indexnames to the internalized types/terms, the second is the
  certified instantiation suitable for Drule.instantiate.

  Errors are raised (via error_var) for unknown variables, wrong kinds of
  arguments, sort violations and incompatible types.*)
fun read_insts ctxt mixed_insts (tvars, vars) =
  let
    val thy = ProofContext.theory_of ctxt;
    val cert = Thm.cterm_of thy;
    val certT = Thm.ctyp_of thy;

    (*term arguments are either already internal (Args.Term) or source text
      still to be read (Args.Text)*)
    val (type_insts, term_insts) = List.partition (is_tvar o fst) mixed_insts;
    val internal_insts = term_insts |> map_filter
      (fn (xi, Args.Term t) => SOME (xi, t)
        | (_, Args.Text _) => NONE
        | (xi, _) => error_var "Term argument expected for " xi);
    val external_insts = term_insts |> map_filter
      (fn (xi, Args.Text s) => SOME (xi, s) | _ => NONE);


    (* mixed type instantiations *)

    (*read (Args.Text) or take (Args.Typ) the type, then check it against the
      sort of the type variable*)
    fun readT (xi, arg) =
      let
        val S = the_sort tvars xi;
        val T =
          (case arg of
            Args.Text s => ProofContext.read_typ ctxt s
          | Args.Typ T => T
          | _ => error_var "Type argument expected for " xi);
      in
        if Sign.of_sort thy (T, S) then ((xi, S), T)
        else error_var "Incompatible sort for typ instantiation of " xi
      end;

    val type_insts1 = map readT type_insts;
    val instT1 = TermSubst.instantiateT type_insts1;
    val vars1 = map (apsnd instT1) vars;


    (* internal term instantiations *)

    (*unify the variable types with the types of the internal terms*)
    val instT2 = Envir.norm_type
      (#1 (fold (unify_vartypes thy vars1) internal_insts (Vartab.empty, 0)));
    val vars2 = map (apsnd instT2) vars1;
    val internal_insts2 = map (apsnd (map_types instT2)) internal_insts;
    (*NOTE(review): inst2 matches Vars at their instT2 types, but it is later
      applied to vars3, which already carry instT3; if instT3 changes such a
      type, the internal instantiation might not apply -- confirm whether
      internal_insts3 was intended here*)
    val inst2 = instantiate internal_insts2;


    (* external term instantiations *)

    (*read the source strings against the expected variable types; inferred
      is applied below as a TVar substitution (instT3)*)
    val (xs, strs) = split_list external_insts;
    val Ts = map (the_type vars2) xs;
    val (ts, inferred) =   (* FIXME polymorphic!? schematic vs. 'for' context!? *)
      ProofContext.read_termTs_schematic ctxt (K false) (K NONE) (K NONE) [] (strs ~~ Ts);

    val instT3 = Term.typ_subst_TVars inferred;
    val vars3 = map (apsnd instT3) vars2;
    val internal_insts3 = map (apsnd (map_types instT3)) internal_insts2;
    val external_insts3 = xs ~~ ts;
    val inst3 = instantiate external_insts3;


    (* results *)

    val type_insts3 = map (fn ((a, _), T) => (a, instT3 (instT2 T))) type_insts1;
    val term_insts3 = internal_insts3 @ external_insts3;

    (*only variables that are actually changed are kept (make_instT, make_inst)*)
    val inst_tvars = map_filter (make_instT (instT3 o instT2 o instT1)) tvars;
    val inst_vars = map_filter (make_inst (inst3 o inst2)) vars3;
  in
    ((type_insts3, term_insts3),
      (map (pairself certT) inst_tvars, map (pairself cert) inst_vars))
  end;

(*instantiate thm by mixed_insts, a list of (arg, (xi, value)) pairs; as a
  side effect each source argument arg is assigned its internalized type or
  term, and the result keeps the case information of thm (RuleCases.save)*)
fun read_instantiate ctxt mixed_insts thm =
  let
    fun declare_used a = Variable.declare_internal (Logic.mk_type (TFree (a, [])));
    val ctxt' =
      fold declare_used (Drule.add_used thm []) (Variable.declare_thm thm ctxt);  (* FIXME tmp *)
    val tvars = Drule.fold_terms Term.add_tvars thm [];
    val vars = Drule.fold_terms Term.add_vars thm [];
    val ((type_insts, term_insts), insts) =
      read_insts ctxt' (map snd mixed_insts) (tvars, vars);

    (*assign internalized values*)
    fun internalized xi =
      if is_tvar xi then Args.Typ (the (AList.lookup (op =) type_insts xi))
      else Args.Term (the (AList.lookup (op =) term_insts xi));
    val _ =
      List.app (fn (arg, (xi, _)) => Args.assign (SOME (internalized xi)) arg) mixed_insts;
  in
    RuleCases.save thm (Drule.instantiate insts thm)
  end;

(*positional instantiation: args are matched against the variables of the
  whole proposition, concl_args against those of the conclusion, both in
  order of first occurrence; NONE entries skip a variable*)
fun read_instantiate' ctxt (args, concl_args) thm =
  let
    fun zip_vars vars insts =
      (case (vars, insts) of
        (_, []) => []
      | (_ :: vs, (_, NONE) :: rest) => zip_vars vs rest
      | ((x, _) :: vs, (arg, SOME t) :: rest) => (arg, (x, t)) :: zip_vars vs rest
      | ([], _ :: _) => error "More instantiations than variables in theorem");
    fun vars_of t = rev (Term.add_vars t []);
    val insts =
      zip_vars (vars_of (Thm.full_prop_of thm)) args @
      zip_vars (vars_of (Thm.concl_of thm)) concl_args;
  in read_instantiate ctxt insts thm end;

end;



(** attributes **)

(* where: named instantiation *)

local

(*instantiation value: internal type, internal term, or source text*)
val where_value =
  Args.internal_typ >> Args.Typ
  || Args.internal_term >> Args.Term
  || Args.name >> Args.Text;

(*"?x = value", producing (source argument, (indexname, value))*)
val where_inst =
  Args.var -- (Args.$$$ "=" |-- Args.ahead -- where_value)
    >> (fn (var, (arg, v)) => (arg, (var, v)));

fun where_rule args =
  Thm.rule_attribute (fn context => read_instantiate (Context.proof_of context) args);

in

val where_att = Attrib.syntax (Args.and_list (Scan.lift where_inst) >> where_rule);

end;


(* of: positional instantiation (terms only) *)

local

(*instantiation value: internal term or source text*)
val of_value =
  Args.internal_term >> Args.Term
  || Args.name >> Args.Text;

(*optional value, paired with its source argument*)
val of_inst = Args.ahead -- Args.maybe of_value;
val concl_marker = Args.$$$ "concl" -- Args.colon;

(*values for the whole proposition, optionally followed by "concl:" values*)
val of_insts =
  Scan.repeat (Scan.unless concl_marker of_inst)
  -- Scan.optional (concl_marker |-- Scan.repeat of_inst) [];

fun of_rule args =
  Thm.rule_attribute (fn context => read_instantiate' (Context.proof_of context) args);

in

val of_att = Attrib.syntax (Scan.lift of_insts >> of_rule);

end;


(* setup *)

val _ =
  let
    val attribs =
     [("where", where_att, "named instantiation of theorem"),
      ("of", of_att, "positional instantiation of theorem")];
  in Context.add_setup (Attrib.add_attributes attribs) end;



(** methods **)

(* rule_tac etc. -- instantiations refer to the dynamic goal state! *)   (* FIXME cleanup!! *)

(*bires_inst_tac bires_flag ctxt insts thm i: instantiate the rule thm by
  insts -- pairs of indexname and source string, read against the dynamic goal
  state, where names starting with a prime denote type variables -- lift it
  over subgoal i, and compose it with the goal state via compose_tac.
  bires_flag is passed to compose_tac (res_inst_meth uses false,
  eres_inst_meth uses true).

  A subgoal index beyond the number of premises yields no_tac without further
  ado; TERM and THM failures while building the rule are reported as warnings
  and also yield no_tac; reading errors (error) propagate.*)
fun bires_inst_tac bires_flag ctxt insts thm =
  let
    val thy = ProofContext.theory_of ctxt;

    (* separate type and term instantiations *)
    fun has_type_var ((x, _), _) = String.isPrefix "'" x;
    val (Tinsts, tinsts) = List.partition has_type_var insts;

    (* instantiate thm within the environment of st and lift it over subgoal i *)
    fun lifted_rule i st =
      let
        (* environment of the goal state: types of variables, sorts of type variables *)
        val (types, sorts) = types_sorts st;

        (* type instantiations, read with the sorts of st and checked against thm *)
        fun absent xi = error
          ("No such variable in theorem: " ^ Syntax.string_of_vname xi);
        val (rtypes, rsorts) = types_sorts thm;
        fun readT (xi, s) =
          let
            val S = (case rsorts xi of SOME S => S | NONE => absent xi);
            val T = Sign.read_typ (thy, sorts) s;
            val U = TVar (xi, S);
          in
            if Sign.typ_instance thy (T, U) then (U, T)
            else error ("Instantiation of " ^ Syntax.string_of_vname xi ^ " fails")
          end;
        val Tinsts_env = map readT Tinsts;

        (* types of the instantiated rule variables, after applying Tinsts_env *)
        fun get_typ xi =
          (case rtypes xi of
            SOME T => typ_subst_atomic Tinsts_env T
          | NONE => absent xi);
        val (xis, ss) = Library.split_list tinsts;
        val Ts = map get_typ xis;

        (* parameters of subgoal i as they are printed: bound variables with
           the same name are renamed during printing *)
        val (_, _, Bi, _) = dest_state (st, i);
        val params = rev (Term.rename_wrt_term Bi (Logic.strip_params Bi));
        fun types' (a, ~1) =
              (case AList.lookup (op =) params a of
                NONE => types (a, ~1)
              | some => some)
          | types' xi = types xi;
        fun internal x = is_some (types' (x, ~1));
        val used = Drule.add_used thm (Drule.add_used st []);

        (* term instantiations *)
        val (ts, envT) =
          ProofContext.read_termTs_schematic ctxt internal types' sorts used (ss ~~ Ts);
        val envT' = map (fn (ixn, T) => (TVar (ixn, the (rsorts ixn)), T)) envT @ Tinsts_env;
        val cenv =
          map (fn (xi, t) => pairself (Thm.cterm_of thy) (Var (xi, fastype_of t), t))
            (distinct (fn ((x1, t1), (x2, t2)) => x1 = x2 andalso t1 aconv t2) (xis ~~ ts));

        (* lift the instantiation over the parameters of subgoal i *)
        val {maxidx, ...} = rep_thm st;
        val paramTs = map #2 params;
        val inc = maxidx + 1;
        fun liftvar (Var ((a, j), T)) = Var ((a, j + inc), paramTs ---> Logic.incr_tvar inc T)
          | liftvar t = raise TERM ("Variable expected", [t]);
        fun liftterm t = list_abs_free (params, Logic.incr_indexes (paramTs, inc) t);
        fun liftpair (cv, ct) = (cterm_fun liftvar cv, cterm_fun liftterm ct);
        val lifttvar = pairself (ctyp_of thy o Logic.incr_tvar inc);
      in
        Drule.instantiate (map lifttvar envT', map liftpair cenv)
          (Thm.lift_rule (Thm.cprem_of st i) thm)
      end;

    (* The subgoal index is checked before the goal state is inspected: the
       rule construction calls dest_state and Thm.cprem_of eagerly, which
       would otherwise raise THM and emit a spurious warning. *)
    fun tac i st =
      if i > nprems_of st then no_tac st
      else
        (compose_tac (bires_flag, lifted_rule i st, nprems_of thm) i st
          handle TERM (msg, _) => (warning msg; no_tac st)
            | THM (msg, _, _) => (warning msg; no_tac st));
  in tac end;

local

(*gen_inst inst_tac tac ctxt (quant, (insts, thms)): method that inserts the
  facts and then applies tac to the rules thms if there are no instantiations,
  or inst_tac ctxt insts to the single rule otherwise; quant comes from the
  goal specification*)
fun gen_inst inst_tac tac ctxt (quant, (insts, thms)) =
  (case (insts, thms) of
    ([], _) =>
      Method.METHOD (fn facts => quant (Method.insert_tac facts THEN' tac thms))
  | (_, [thm]) =>
      Method.METHOD (fn facts =>
        quant (Method.insert_tac facts THEN' inst_tac ctxt insts thm))
  | _ => error "Cannot have instantiations with multiple rules");

in

val res_inst_meth = gen_inst (bires_inst_tac false) Tactic.resolve_tac;
val eres_inst_meth = gen_inst (bires_inst_tac true) Tactic.eresolve_tac;

(*the remaining variants instantiate Tactic.make_elim_preserve rule*)
val cut_inst_meth =
  let
    fun inst_cut ctxt insts rule =
      bires_inst_tac false ctxt insts (Tactic.make_elim_preserve rule);
  in gen_inst inst_cut Tactic.cut_rules_tac end;

val dres_inst_meth =
  let
    fun inst_dres ctxt insts rule =
      bires_inst_tac true ctxt insts (Tactic.make_elim_preserve rule);
  in gen_inst inst_dres Tactic.dresolve_tac end;

val forw_inst_meth =
  let
    fun inst_forw ctxt insts rule =
      bires_inst_tac false ctxt insts (Tactic.make_elim_preserve rule) THEN' assume_tac;
  in gen_inst inst_forw Tactic.forward_tac end;

(*insert proposition sprop (read against the goal state) as a new subgoal by
  instantiating ?psi of cut_rl; the tactic is wrapped in DETERM*)
fun subgoal_tac ctxt sprop =
  bires_inst_tac false ctxt [(("psi", 0), sprop)] cut_rl #> DETERM;

fun subgoals_tac ctxt = EVERY' o map (subgoal_tac ctxt);

(*remove a premise by elim-resolution with thin_rl, instantiating ?V by prop*)
fun thin_tac ctxt prop =
  bires_inst_tac true ctxt [(("V", 0), prop)] thin_rl;


(* method syntax *)

(*optional prefix "?x = s and ... in" of variable instantiations, followed by
  the rules; variables are parsed by Args.var, values by Args.name.
  (The former Args.name-based "insts"/"inst_args" variants were unused and
  hidden by the RULE_INSTS signature, so they have been dropped.)*)
val insts_var =
  Scan.optional
    (Args.enum1 "and" (Scan.lift (Args.var -- (Args.$$$ "=" |-- Args.!!! Args.name))) --|
      Scan.lift (Args.$$$ "in")) [] -- Attrib.thms;

(*parse goal specification and instantiations from src, then pass them to f*)
fun inst_args_var f src ctxt =
  f ctxt (#2 (Method.syntax (Args.goal_spec HEADGOAL -- insts_var) src ctxt));


(* setup *)

val _ =
  let
    val methods =
     [("rule_tac", inst_args_var res_inst_meth,
        "apply rule (dynamic instantiation)"),
      ("erule_tac", inst_args_var eres_inst_meth,
        "apply rule in elimination manner (dynamic instantiation)"),
      ("drule_tac", inst_args_var dres_inst_meth,
        "apply rule in destruct manner (dynamic instantiation)"),
      ("frule_tac", inst_args_var forw_inst_meth,
        "apply rule in forward manner (dynamic instantiation)"),
      ("cut_tac", inst_args_var cut_inst_meth,
        "cut rule (dynamic instantiation)"),
      ("subgoal_tac", Method.goal_args_ctxt (Scan.repeat1 Args.name) subgoals_tac,
        "insert subgoal (dynamic instantiation)"),
      ("thin_tac", Method.goal_args_ctxt Args.name thin_tac,
        "remove premise (dynamic instantiation)")];
  in Context.add_setup (Method.add_methods methods) end;

end;

end;