(* Title: Pure/Isar/calculation.ML
Author: Markus Wenzel, TU Muenchen
Generic calculational proofs.
*)
(*Interface of calculational proof support: printing and inspection of the
  calculation state, rule declaration attributes, and the proof state
  transformers behind the calculation commands.*)
signature CALCULATION =
sig
(*print the declared transitivity and symmetry rules of the context*)
val print_rules: Proof.context -> unit
(*result of the current calculation, if one is recorded at the present proof level*)
val check: Proof.state -> thm list option
(*add / delete a transitivity rule*)
val trans_add: attribute
val trans_del: attribute
(*add / delete a symmetry rule*)
val sym_add: attribute
val sym_del: attribute
(*resolve a theorem with the declared symmetry rules; exactly one result is required*)
val symmetric: attribute
(*"also" and "finally": optional explicit rules, then a flag that enables
  printing of the resulting calculation; results are a lazy sequence that
  ends in an error element.  The _cmd variants take unevaluated fact
  references, which are evaluated via Attrib.eval_thms.*)
val also: thm list option -> bool -> Proof.state -> Proof.state Seq.result Seq.seq
val also_cmd: (Facts.ref * Token.src list) list option ->
bool -> Proof.state -> Proof.state Seq.result Seq.seq
val finally: thm list option -> bool -> Proof.state -> Proof.state Seq.result Seq.seq
val finally_cmd: (Facts.ref * Token.src list) list option -> bool ->
Proof.state -> Proof.state Seq.result Seq.seq
(*"moreover" and "ultimately": append the current facts to the calculation;
  the flag enables printing of the resulting calculation*)
val moreover: bool -> Proof.state -> Proof.state
val ultimately: bool -> Proof.state -> Proof.state
end;
structure Calculation: CALCULATION =
struct
(** calculation data **)
(*A recorded calculation: the accumulated facts, the proof level at which they
  were recorded, and the serial number and position used for entity reports.*)
type calculation = {result: thm list, level: int, serial: serial, pos: Position.T};

(*Generic context data: (transitivity rules, symmetry rules) paired with the
  optional current calculation.  Merging combines the rule nets and always
  yields an empty calculation slot.*)
structure Data = Generic_Data
(
  type T = (thm Item_Net.T * thm Item_Net.T) * calculation option;
  val empty = ((Thm.item_net_elim, Thm.item_net), NONE);
  fun merge ((rules_a, _), (rules_b, _)) =
    let
      val (trans_a, sym_a) = rules_a;
      val (trans_b, sym_b) = rules_b;
      val trans = Item_Net.merge (trans_a, trans_b);
      val sym = Item_Net.merge (sym_a, sym_b);
    in ((trans, sym), NONE) end;
);
(*declared rules of a proof context: (transitivity net, symmetry net)*)
fun get_rules ctxt = #1 (Data.get (Context.Proof ctxt));

(*content of the transitivity and symmetry nets as plain lists*)
fun get_trans_rules ctxt = Item_Net.content (#1 (get_rules ctxt));
fun get_sym_rules ctxt = Item_Net.content (#2 (get_rules ctxt));

(*calculation slot stored in a proof context*)
fun get_calculation ctxt = #2 (Data.get (Context.Proof ctxt));
(*Print the declared transitivity and symmetry rules as two titled lists.*)
fun print_rules ctxt =
  let
    val item = Thm.pretty_thm_item ctxt;
    fun section title ths = Pretty.big_list title (map item ths);
  in
    Pretty.writeln_chunks
     [section "transitivity rules:" (get_trans_rules ctxt),
      section "symmetry rules:" (get_sym_rules ctxt)]
  end;
(* access calculation *)
(*The calculation stored in the context of the proof state, but only if it
  was recorded at the current proof level; otherwise NONE.*)
fun check_calculation state =
  let
    fun current (calculation as {level, ...}: calculation) =
      if level = Proof.level state then SOME calculation else NONE;
  in Option.mapPartial current (get_calculation (Proof.context_of state)) end;
(*facts of the calculation that is valid at the current proof level, if any*)
fun check state = Option.map #result (check_calculation state);

(*name of the fact that holds the calculation; also used as entity kind in reports*)
val calculationN = "calculation";
(*Record (SOME result) or clear (NONE) the calculation of a proof state: the
  context data slot is replaced and the fact "calculation" is updated via
  Proof_Context.put_thms.  When there is no calculation at the current level,
  a new entry with fresh serial and the current thread position is created
  and reported with def = true; otherwise level, serial and position of the
  existing entry are kept and the report uses def = false.*)
fun update_calculation calc state =
  let
    fun report def id pos =
      let val markup = Position.make_entity_markup def id calculationN ("", pos)
      in Context_Position.report (Proof.context_of state) (Position.thread_data ()) markup end;

    fun record result =
      let
        (*tuple components are evaluated left-to-right: level, serial, position*)
        val (def, level, id, pos) =
          (case check_calculation state of
            NONE => (true, Proof.level state, serial (), Position.thread_data ())
          | SOME {level, serial, pos, ...} => (false, level, serial, pos));
        val _ = report {def = def} id pos;
      in {result = result, level = level, serial = id, pos = pos} end;

    val calculation = Option.map record calc;

    val put_data =
      (Proof.map_context o Context.proof_map o Data.map o apsnd) (K calculation);
    val put_fact =
      Proof.map_context (Proof_Context.put_thms false (calculationN, calc));
  in state |> put_data |> put_fact end;
(** attributes **)
(* add/del rules *)
(*add a transitivity rule; the theorem is passed through Thm.trim_context
  before it is put into the net*)
val trans_add =
  Thm.declaration_attribute (fn th =>
    (Data.map o apfst o apfst) (Item_Net.update (Thm.trim_context th)));

(*remove a transitivity rule from the net*)
val trans_del =
  Thm.declaration_attribute (fn th =>
    (Data.map o apfst o apfst) (Item_Net.remove th));
(*add a symmetry rule (passed through Thm.trim_context) and afterwards apply
  the Context_Rules.elim_query NONE attribute to the same theorem*)
val sym_add =
  Thm.declaration_attribute (fn th => fn context =>
    context
    |> (Data.map o apfst o apsnd) (Item_Net.update (Thm.trim_context th))
    |> Thm.attribute_declaration (Context_Rules.elim_query NONE) th);

(*remove a symmetry rule and afterwards apply Context_Rules.rule_del to it*)
val sym_del =
  Thm.declaration_attribute (fn th => fn context =>
    context
    |> (Data.map o apfst o apsnd) (Item_Net.remove th)
    |> Thm.attribute_declaration Context_Rules.rule_del th);
(* symmetric *)
(*Resolve the theorem with the declared symmetry rules.  At most two results
  are inspected: a unique result is returned with zero-indexed variables,
  anything else raises THM.*)
val symmetric =
  Thm.rule_attribute [] (fn context => fn th =>
    let
      val ctxt = Context.proof_of context;
      fun fail msg = raise THM (msg, 1, [th]);
      val (candidates, _) =
        Seq.chop 2 (Drule.multi_resolves (SOME ctxt) [th] (get_sym_rules ctxt));
    in
      (case candidates of
        [] => fail "symmetric: no unifiers"
      | [th'] => Drule.zero_var_indexes th'
      | _ => fail "symmetric: multiple unifiers")
    end);
(* concrete syntax *)
(*Theory setup: attribute syntax for "trans", "sym" and "symmetric", followed
  by declaration of transitive_thm and symmetric_thm as initial rules.*)
val _ =
  let
    val attribute_syntax =
      Attrib.setup \<^binding>\<open>trans\<close> (Attrib.add_del trans_add trans_del)
        "declaration of transitivity rule" #>
      Attrib.setup \<^binding>\<open>sym\<close> (Attrib.add_del sym_add sym_del)
        "declaration of symmetry rule" #>
      Attrib.setup \<^binding>\<open>symmetric\<close> (Scan.succeed symmetric)
        "resolution with symmetry rule";
    (*the resulting theorems are discarded, only the theory is kept*)
    val basic_rules =
      Global_Theory.add_thms
       [((Binding.empty, transitive_thm), [trans_add]),
        ((Binding.empty, symmetric_thm), [sym_add])] #> snd;
  in Theory.setup (attribute_syntax #> basic_rules) end;
(** proof commands **)
(*State check for calculation commands.  Final commands require forward mode.
  Other commands accept forward or chain mode; if the state is in chain mode,
  Markup.improper is reported at the current thread position.*)
fun assert_sane true = Proof.assert_forward
  | assert_sane false =
      let
        fun mark_improper state =
          if can Proof.assert_chain state then
            Context_Position.report (Proof.context_of state)
              (Position.thread_data ()) Markup.improper
          else ();
      in Proof.assert_forward_or_chain #> tap mark_improper end;
(*Store calc as the current calculation and reset the facts of the state.
  If int is true, the fact is printed under its full name.  For a final step
  the calculation is cleared again and calc is chained as facts.*)
fun maintain_calculation int final calc state =
  let
    val state' = Proof.improper_reset_facts (update_calculation (SOME calc) state);
    val ctxt' = Proof.context_of state';

    fun print_calculation () =
      let val name = Proof_Context.full_name ctxt' (Binding.name calculationN)
      in writeln (Pretty.string_of (Proof_Context.pretty_fact ctxt' (name, calc))) end;

    val _ = if int then print_calculation () else ();
  in
    if final then Proof.chain_facts calc (update_calculation NONE state')
    else state'
  end;
(* also and finally *)
(*Common implementation of "also" and "finally".

  prep_rules turns the optional raw_rules into theorems; final selects the
  finishing variant; int enables printing of the result.  Without a
  calculation at the current level, the present facts become the initial
  calculation (explicit rules and final are then errors).  Otherwise the
  previous calculation followed by the present facts is resolved against the
  given rules, or against transitivity rules from the context, and each
  outcome yields one result state.  The result sequence always ends with a
  "No matching trans rules" error element.*)
fun calculate prep_rules final raw_rules int state =
  let
    val ctxt = Proof.context_of state;
    val pretty_thm = Thm.pretty_thm ctxt;
    val pretty_thm_item = Thm.pretty_thm_item ctxt;

    (*conclusions are compared modulo beta-eta contraction*)
    fun concl_of th = Logic.strip_assums_concl (Thm.prop_of th);
    fun norm_concl th = Envir.beta_eta_contract (concl_of th);
    fun same_concl th1 th2 = norm_concl th1 aconv norm_concl th2;

    (*reject a result whose conclusion coincides with one of the inputs*)
    fun check_projection ths th =
      let val i = find_index (same_concl th) ths in
        if i < 0 then Seq.Result [th]
        else
          Seq.Error (fn () =>
            let
              val header =
                Pretty.block
                  [Pretty.str "Vacuous calculation result:", Pretty.brk 1, pretty_thm th];
              val origin =
                Pretty.str ("derived as projection (" ^ string_of_int (i + 1) ^ ") from:");
              val body = Pretty.block (Pretty.fbreaks (origin :: map pretty_thm_item ths));
            in Pretty.string_of (Pretty.chunks [header, body]) end)
      end;

    (*evaluated before the state is checked, as rule preparation may fail*)
    val opt_rules = Option.map (prep_rules ctxt) raw_rules;

    (*explicit rules take precedence; otherwise consult the transitivity net,
      indexed by the conclusion of the first theorem if there is one*)
    fun candidate_rules ths =
      (case (opt_rules, ths) of
        (SOME rules, _) => rules
      | (NONE, []) => Item_Net.content (#1 (get_rules ctxt))
      | (NONE, th :: _) => Item_Net.retrieve (#1 (get_rules ctxt)) (concl_of th));

    fun no_match ths =
      Seq.Error (fn () =>
        Pretty.string_of (Pretty.block (Pretty.fbreaks
          (Pretty.str "No matching trans rules for calculation:" ::
            map pretty_thm_item ths))));

    fun combine ths =
      let
        val results =
          Seq.of_list (candidate_rules ths)
          |> Seq.maps (Drule.multi_resolve (SOME ctxt) ths)
          |> Seq.map (check_projection ths);
      in Seq.append results (Seq.single (no_match ths)) end;

    val facts = Proof.the_facts (assert_sane final state);
    val (initial, calculations) =
      (case check state of
        SOME calc => (false, combine (calc @ facts))
      | NONE => (true, Seq.single (Seq.Result facts)));

    val _ = if initial andalso final then error "No calculation yet" else ();
    val _ =
      if initial andalso is_some opt_rules
      then error "Initial calculation -- no rules to be given"
      else ();
  in
    Seq.map_result (fn calc => maintain_calculation int final calc state) calculations
  end;
(*variants on given theorems: rules are used as they are*)
fun also rules int state = calculate (K I) false rules int state;
fun finally rules int state = calculate (K I) true rules int state;

(*command variants: rules are fact references, evaluated by Attrib.eval_thms*)
fun also_cmd raw_rules int state = calculate Attrib.eval_thms false raw_rules int state;
fun finally_cmd raw_rules int state = calculate Attrib.eval_thms true raw_rules int state;
(* moreover and ultimately *)
(*Common implementation of "moreover" and "ultimately": the present facts are
  appended to the calculation of the current level (empty if there is none).
  A final step without a previous calculation is an error.*)
fun collect final int state =
  let
    val facts = Proof.the_facts (assert_sane final state);
    val previous = check state;
    val calc = these previous @ facts;
    val _ = is_none previous andalso final andalso error "No calculation yet";
  in maintain_calculation int final calc state end;
(*non-final and final variant of fact collection*)
fun moreover int state = collect false int state;
fun ultimately int state = collect true int state;
end;