src/HOL/Tools/inductive_realizer.ML
 author wenzelm Sat, 29 Mar 2008 19:14:03 +0100 changeset 26481 92e901171cc8 parent 26477 ecf06644f6cb child 26535 66bca8a4079c permissions -rw-r--r--
simplified PureThy.store_thm;
```
(*  Title:      HOL/Tools/inductive_realizer.ML
ID:         \$Id\$
Author:     Stefan Berghofer, TU Muenchen

Program extraction from proofs involving inductive predicates:
Realizers for induction and elimination rules
*)

(* Public interface of the inductive realizer module. *)
signature INDUCTIVE_REALIZER =
sig
(* add_ind_realizers name rsets thy: looks up the inductive definition `name`
   and folds add_ind_realizer over subsets of its relevant variables;
   rsets is handed on unchanged (the code below tests predicate names
   against it with `s mem rsets`). *)
val add_ind_realizers: string -> string list -> theory -> theory
(* Theory setup function. *)
val setup: theory -> theory
end;

structure InductiveRealizer : INDUCTIVE_REALIZER =
struct

(* FIXME: LocalTheory.note should return theorems with proper names! *)
(* Recovers the name of a theorem from its proof: the proof must mention
   exactly one named theorem, otherwise an error is raised. *)
fun name_of_thm thm =
  let
    val named_thms = Symtab.dest (Proofterm.thms_of_proof' (proof_of thm) Symtab.empty)
  in
    (case named_thms of
      [(name, _)] => name
    | _ => error ("name_of_thm: bad proof of theorem\n" ^ string_of_thm thm))
  end;

(* The "HOL.all_simps" rules, turned into meta-equations and then reversed. *)
val all_simps =
  map (fn th => symmetric (mk_meta_eq th)) (thms "HOL.all_simps");

(* Reconstructs the full proof term of thm and expands it with
   Reconstruct.expand_proof. *)
fun prf_of thm =
  let
    val {thy, prop, der = (_, prf), ...} = rep_thm thm;
    val expand = Reconstruct.expand_proof thy [("", NONE)];
  in expand (Reconstruct.reconstruct_proof thy prop prf) end; (* FIXME *)

(* Abstracts the proof prf over the variable t (a Var or a Free),
   producing an Abst node named after the variable. *)
fun forall_intr_prf (t, prf) =
  let
    val (x, U) =
      (case t of
        Var ((x, _), U) => (x, U)
      | Free (x, U) => (x, U));
    val body = Proofterm.prf_abstract_over t prf;
  in Abst (x, SOME U, body) end;

(* Universally quantifies the term u over the variable t (a Var or a Free). *)
fun forall_intr_term (t, u) =
  let
    val (x, U) =
      (case t of
        Var ((x, _), U) => (x, U)
      | Free (x, U) => (x, U));
    val body = abstract_over (t, u);
  in all U \$ Abs (x, U, body) end;

(* All sublists of a list: first those without the head element,
   then the same sublists with the head element added in front. *)
fun subsets [] = [[]]
  | subsets (x :: rest) =
      let val without_x = subsets rest
      in without_x @ map (fn ys => x :: ys) without_x end;

(* Name of the constant at the head of an application. *)
fun pred_of t = fst (dest_Const (head_of t));

(* Replaces outer meta-quantifiers by free variables, descending through the
   conclusions of meta-implications.  Names are taken from `names` while it is
   non-empty; afterwards a variant of the bound name that avoids `used`
   is chosen. *)
fun strip_all' used names (Const ("all", _) \$ Abs (s, T, t)) =
      let
        val (fresh, remaining) =
          (case names of
            [] => (Name.variant used s, [])
          | name :: rest => (name, rest))
      in strip_all' (fresh :: used) remaining (subst_bound (Free (fresh, T), t)) end
  | strip_all' used names ((imp as Const ("==>", _) \$ P) \$ Q) =
      imp \$ strip_all' used names Q
  | strip_all' _ _ t = t;

(* strip_all' with no preferred names, avoiding the free names of t. *)
fun strip_all t =
  let val used = add_term_free_names (t, [])
  in strip_all' used [] t end;

(* Splits "!!x. P ==> Q" (instantiating x with a free variable called `name`)
   or "P ==> Q" into the pair (P, Q). *)
fun strip_one name (Const ("all", _) \$ Abs (s, T, Const ("==>", _) \$ P \$ Q)) =
      let val x = Free (name, T)
      in (subst_bound (x, P), subst_bound (x, Q)) end
  | strip_one _ (Const ("==>", _) \$ P \$ Q) = (P, Q);

(* Schematic variables of prop whose result type is a type constructor
   named "bool", as (name, type) pairs in the order given by term_vars. *)
fun relevant_vars prop =
  let
    fun pick (Var ((a, _), T)) =
          (case strip_type T of
            (_, Type (s, _)) => if s = "bool" then SOME (a, T) else NONE
          | _ => NONE)
      | pick _ = NONE
  in List.mapPartial pick (term_vars prop) end;

(* Builds a datatype specification (type parameters, type name, syntax,
   constructors) from a list of introduction rules: one constructor per rule,
   named after the rule, whose arguments are the types of the rule's
   non-parameter variables followed by the non-null extracted types of
   its premises. *)
fun dt_of_intrs thy vs nparms intrs =
  let
    val first_prop = prop_of (hd intrs);
    val iTs = term_tvars first_prop;
    val (Const (s, _), ts) =
      strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl first_prop));
    val params = map dest_Var (Library.take (nparms, ts));
    val tname = space_implode "_" (Sign.base_name s ^ "T" :: vs);
    fun constr_of_intr intr =
      let
        val cname = Sign.base_name (name_of_thm intr);
        val var_Ts = map (Logic.unvarifyT o snd)
          (rev (Term.add_vars (prop_of intr) []) \\ params);
        val prem_Ts = filter_out (equal Extraction.nullT)
          (map (Logic.unvarifyT o Extraction.etype_of thy vs []) (prems_of intr));
      in (cname, var_Ts @ prem_Ts, NoSyn) end;
    val tvar_names = map (fn a => "'" ^ a) vs @ map (fst o fst) iTs;
  in (tvar_names, tname, NoSyn, map constr_of_intr intrs) end;

(* The constant "realizes" at type T => bool => bool. *)
fun mk_rlz T =
  let val rlzT = T --> HOLogic.boolT --> HOLogic.boolT
  in Const ("realizes", rlzT) end;

(** turn "P" into "%r x. realizes r (P x)" **)

(* For a schematic variable with index 0 and result type bool, builds the
   abstraction over its arguments stating that it is realized: by an extra
   bound variable "r" when the variable's name occurs in vs, by
   Extraction.nullt otherwise.  All other terms are returned unchanged. *)
fun gen_rvar vs (t as Var ((a, 0), T)) =
      if body_type T = HOLogic.boolT then
        let
          val arg_Ts = binder_types T;
          val n = length arg_Ts;
          val binders = map (pair "x") arg_Ts;
          val applied = list_comb (t, map Bound (n - 1 downto 0));
        in
          if a mem vs then
            let val U = TVar (("'" ^ a, 0), HOLogic.typeS)
            in list_abs (("r", U) :: binders, mk_rlz U \$ Bound n \$ applied) end
          else
            list_abs (binders, mk_rlz Extraction.nullT \$ Extraction.nullt \$ applied)
        end
      else t
  | gen_rvar _ t = t;

(* mk_realizes_eqn n vs nparms intrs: builds a pair of conditional equations,
   each of the shape (prems, (lhs, rhs)), for the predicate at the head of
   the conclusion of the first rule in intrs:
     - a "typeof" equation whose right-hand side is Extraction.mk_typ rT;
     - a "realizes" equation whose right-hand side applies the constant
       named rname (predicate name ^ "R" joined with vs by "_").
   When n is true, Extraction.nullT / Extraction.nullt take the place of
   the realizer type rT and the realizer variable r. *)
fun mk_realizes_eqn n vs nparms intrs =
let
val concl = HOLogic.dest_Trueprop (concl_of (hd intrs));
val iTs = term_tvars concl;
val Tvs = map TVar iTs;
val (h as Const (s, T), us) = strip_comb concl;
(* the first nparms arguments are the parameters; the remaining argument
   types of the predicate are elTs *)
val params = List.take (us, nparms);
val elTs = List.drop (binder_types T, nparms);
val predT = elTs ---> HOLogic.boolT;
val used = map (fst o fst o dest_Var) params;
(* schematic variables named after "x" (avoiding the parameter names),
   one per non-parameter argument *)
val xs = map (Var o apfst (rpair 0))
(Name.variant_list used (replicate (length elTs) "x") ~~ elTs);
val rT = if n then Extraction.nullT
else Type (space_implode "_" (s ^ "T" :: vs),
map (fn a => TVar (("'" ^ a, 0), HOLogic.typeS)) vs @ Tvs);
val r = if n then Extraction.nullt else Var ((Sign.base_name s, 0), rT);
val S = list_comb (h, params @ xs);
val rvs = relevant_vars S;
(* relevant variables of S that are not listed in vs *)
val vs' = map fst rvs \\ vs;
val rname = space_implode "_" (s ^ "R" :: vs);

(* "typeof" premise for variable v; note that this n shadows the outer n *)
fun mk_Tprem n v =
let val T = (the o AList.lookup (op =) rvs) v
in (Const ("typeof", T --> Type ("Type", [])) \$ Var ((v, 0), T),
Extraction.mk_typ (if n then Extraction.nullT
else TVar (("'" ^ v, 0), HOLogic.typeS)))
end;

(* variables outside vs get the null type, those in vs a type variable *)
val prems = map (mk_Tprem true) vs' @ map (mk_Tprem false) vs;
val ts = map (gen_rvar vs) params;
val argTs = map fastype_of ts;

in ((prems, (Const ("typeof", HOLogic.boolT --> Type ("Type", [])) \$ S,
Extraction.mk_typ rT)),
(prems, (mk_rlz rT \$ r \$ S,
if n then list_comb (Const (rname, argTs ---> predT), ts @ xs)
else list_comb (Const (rname, argTs @ [rT] ---> predT), ts @ [r] @ xs))))
end;

(* fun_of_prem thy rsets vs params rule ivs intr: builds a lambda-abstracted
   term from the premise `rule` (after strip_all) that belongs to the
   introduction rule intr.  The premises of the stripped rule are walked by
   the local function fun_of, which accumulates the abstracted variables
   (ts), the variables standing for results of recursive premises (rts) and
   the arguments (args) that are finally passed to the free variable named
   "r" ^ base name of intr.  If the conclusion has null extracted type, the
   body of the abstraction is HOLogic.unit instead. *)
fun fun_of_prem thy rsets vs params rule ivs intr =
let
val ctxt = ProofContext.init thy
val args = map (Free o apfst fst o dest_Var) ivs;
val args' = map (Free o apfst fst)
(Term.add_vars (prop_of intr) [] \\ params);
val rule' = strip_all rule;
val conclT = Extraction.etype_of thy vs [] (Logic.strip_imp_concl rule');
val used = map (fst o dest_Free) args;

(* a premise is recursive if it mentions one of the constants in rsets *)
fun is_rec t = not (null (term_consts t inter rsets));

(* follows quantifiers and conclusions of implications down to the Trueprop;
   true if its head is a constant for which the_inductive succeeds, or is
   not a constant at all *)
fun is_meta (Const ("all", _) \$ Abs (s, _, P)) = is_meta P
| is_meta (Const ("==>", _) \$ _ \$ Q) = is_meta Q
| is_meta (Const ("Trueprop", _) \$ t) = (case head_of t of
Const (s, _) => can (InductivePackage.the_inductive ctxt) s
| _ => true)
| is_meta _ = false;

fun fun_of ts rts args used (prem :: prems) =
let
val T = Extraction.etype_of thy vs [] prem;
val [x, r] = Name.variant_list used ["x", "r"]
in if T = Extraction.nullT
(* premises without computational content are skipped *)
then fun_of ts rts args used prems
else if is_rec prem then
if is_meta prem then
let
(* a recursive "meta" premise is consumed together with the premise
   following it; this pattern fails if no premise follows *)
val prem' :: prems' = prems;
val U = Extraction.etype_of thy vs [] prem';
in if U = Extraction.nullT
then fun_of (Free (x, T) :: ts)
(Free (r, binder_types T ---> HOLogic.unitT) :: rts)
(Free (x, T) :: args) (x :: r :: used) prems'
else fun_of (Free (x, T) :: ts) (Free (r, U) :: rts)
(Free (r, U) :: Free (x, T) :: args) (x :: r :: used) prems'
end
else (case strip_type T of
(* product result type: split into two function variables and pass
   the pair of their applications as a single argument *)
(Ts, Type ("*", [T1, T2])) =>
let
val fx = Free (x, Ts ---> T1);
val fr = Free (r, Ts ---> T2);
val bs = map Bound (length Ts - 1 downto 0);
val t = list_abs (map (pair "z") Ts,
HOLogic.mk_prod (list_comb (fx, bs), list_comb (fr, bs)))
in fun_of (fx :: ts) (fr :: rts) (t::args)
(x :: r :: used) prems
end
| (Ts, U) => fun_of (Free (x, T) :: ts)
(Free (r, binder_types T ---> HOLogic.unitT) :: rts)
(Free (x, T) :: args) (x :: r :: used) prems)
(* non-recursive premise: one abstracted variable, passed on as argument *)
else fun_of (Free (x, T) :: ts) rts (Free (x, T) :: args)
(x :: used) prems
end
| fun_of ts rts args used [] =
(* the accumulators were built in reverse, hence the rev calls *)
let val xs = rev (rts @ ts)
in if conclT = Extraction.nullT
then list_abs_free (map dest_Free xs, HOLogic.unit)
else list_abs_free (map dest_Free xs, list_comb
(Free ("r" ^ Sign.base_name (name_of_thm intr),
map fastype_of (rev args) ---> conclT), rev args))
end

in fun_of args' [] (rev args) used (Logic.strip_imp_prems rule') end;

(* indrule_realizer thy induct raw_induct rsets params vs rec_names rss intrs dummies:
   builds one realizer term per conjunct of the conclusion of raw_induct.
   A conjunct with null extracted type yields Extraction.nullt; any other
   conjunct consumes the next name from rec_names and yields an abstraction
   that applies the constant of that name to the functions fs (built with
   fun_of_prem) and to the bound argument "x". *)
fun indrule_realizer thy induct raw_induct rsets params vs rec_names rss intrs dummies =
let
val concls = HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of raw_induct));
(* for each rule group whose name is in rsets: the group paired with the
   premises of raw_induct found at the positions of its rules in intrs *)
val premss = List.mapPartial (fn (s, rs) => if s mem rsets then
SOME (rs, map (fn (_, r) => List.nth (prems_of raw_induct,
find_index_eq (prop_of r) (map prop_of intrs))) rs) else NONE) rss;
val fs = maps (fn ((intrs, prems), dummy) =>
let
val fs = map (fn (rule, (ivs, intr)) =>
fun_of_prem thy rsets vs params rule ivs intr) (prems ~~ intrs)
(* when the dummy flag is set, an "arbitrary" function is put in front *)
in if dummy then Const ("arbitrary",
HOLogic.unitT --> body_type (fastype_of (hd fs))) :: fs
else fs
end) (premss ~~ dummies);
val frees = fold Term.add_frees fs [];
val Ts = map fastype_of fs;
fun name_of_fn intr = "r" ^ Sign.base_name (name_of_thm intr)
in
fst (fold_map (fn concl => fn names =>
let val T = Extraction.etype_of thy vs [] concl
in if T = Extraction.nullT then (Extraction.nullt, names) else
let
val Type ("fun", [U, _]) = T;
(* fails if rec_names runs out before the conjuncts do *)
val a :: names' = names
(* abstract over "x" and over those "r..." variables that occur free in fs *)
in (list_abs_free (("x", U) :: List.mapPartial (fn intr =>
Option.map (pair (name_of_fn intr))
(AList.lookup (op =) frees (name_of_fn intr))) intrs,
list_comb (Const (a, Ts ---> T), fs) \$ Free ("x", U)), names')
end
end) concls rec_names)
end;

(* Adds a constructor dname with a single unit argument to the datatype
   specification whose type name equals `name` and sets its flag to true;
   every other entry is returned unchanged. *)
fun add_dummy name dname (entry as (_, (vs, s, mfx, cs))) =
  let val dummy_constr = (dname, [HOLogic.unitT], NoSyn)
  in
    if (name: string) = s then (true, (vs, s, mfx, dummy_constr :: cs))
    else entry
  end;

(* add_dummies f dts used thy: applies the datatype definition function f to
   the specifications in dts (a list of (dummy flag, specification) pairs)
   and returns ((flags, SOME info), theory); an empty dts yields
   (([], NONE), thy).
   If f raises DatatypeAux.Datatype_Empty for some type, a dummy constructor
   (a variant of "Dummy" not occurring in `used`) is added to that type via
   add_dummy and the definition is retried.  Previously the handler returned
   the bare theory, which neither matched the result type of the other
   branches nor used the dummy name it had just computed. *)
fun add_dummies f [] _ thy =
      (([], NONE), thy)
  | add_dummies f dts used thy =
      thy
      |> f (map snd dts)
      |-> (fn dtinfo => pair ((map fst dts), SOME dtinfo))
    handle DatatypeAux.Datatype_Empty name' =>
      let
        val name = Sign.base_name name';
        val dname = Name.variant used "Dummy"
      in
        thy
        |> add_dummies f (map (add_dummy name dname) dts) (dname :: used)
      end;

(* Packages a realizer entry (name, (vs, realizing term, correctness proof)).
   The realizing term rt is abstracted over the non-relevant variables of
   rule unless it is Extraction.nullt; the proof of rrule is generalized
   over the variables of rule (at the types they have in rrule) followed by
   the variables that occur only in rrule, and then passed through
   ProofRewriteRules.un_hhf_proof. *)
fun mk_realizer thy vs (name, rule, rrule, rlz, rt) =
  let
    val rule_prop = prop_of rule;
    val rrule_prop = prop_of rrule;
    val relevant = map fst (relevant_vars rule_prop);
    val rule_vars = rev (Term.add_vars rule_prop []);
    fun is_relevant ((a, _), _) = a mem relevant;
    val vs1 = map Var (filter_out is_relevant rule_vars);
    val rrule_vars = rev (Term.add_vars rrule_prop []);
    fun retyped (ixn, _) = Var (ixn, the (AList.lookup (op =) rrule_vars ixn));
    val vs2 = map retyped rule_vars;
    val rs = map Var (subtract (op = o pairself fst) rule_vars rrule_vars);
    val quantified = vs2 @ rs;
    val rlz' = foldr forall_intr_term rrule_prop quantified;
    val rlz'' = foldr forall_intr_term rlz vs2;
    val extracted =
      if rt = Extraction.nullt then rt
      else foldr (uncurry lambda) rt vs1;
    val corr_prf =
      ProofRewriteRules.un_hhf_proof rlz' rlz''
        (foldr forall_intr_prf (prf_of rrule) quantified);
  in (name, (vs, extracted, corr_prf)) end;

(* Replaces every list element that has an entry in the association list
   tab by that entry; other elements are kept. *)
fun rename tab =
  let
    fun subst x =
      (case AList.lookup (op =) tab x of
        SOME y => y
      | NONE => x)
  in map subst end;

(* add_ind_realizer rsets intrs induct raw_induct elims (thy, vs):
   judging from the section comments and the visible calls, this function
   declares the types named tnames, defines a datatype for the computational
   content of the inductive predicates (via add_datatype_i), introduces a
   realizability predicate from the realizing introduction rules rintrs, and
   then builds realizers for the induction and elimination rules.

   NOTE(review): the text of this function is incomplete as it stands.
   Several pipelines end in a dangling "|>", several "map"/"foldl_map" calls
   lack their list argument, and the headers of at least two local function
   definitions are missing (see the notes below).  It cannot be compiled in
   this form; restore the missing lines from the upstream repository rather
   than guessing at them. *)
fun add_ind_realizer rsets intrs induct raw_induct elims (thy, vs) =
let
val qualifier = NameSpace.qualifier (name_of_thm induct);
val inducts = PureThy.get_thms thy (NameSpace.qualified qualifier "inducts");
val iTs = term_tvars (prop_of (hd intrs));
val ar = length vs + length iTs;
val params = InductivePackage.params_of raw_induct;
val arities = InductivePackage.arities_of raw_induct;
val nparms = length params;
val params' = map dest_Var params;
val rss = InductivePackage.partition_rules raw_induct intrs;
(* NOTE(review): the list argument of this map is missing *)
val rss' = map (fn (((s, rs), (_, arity)), elim) =>
(s, (InductivePackage.infer_intro_vars elim arity rs ~~ rs)))
val (prfx, _) = split_last (NameSpace.explode (fst (hd rss)));
val tnames = map (fn s => space_implode "_" (s ^ "T" :: vs)) rsets;

(* NOTE(review): pipeline ends in a dangling "|>"; its last stage is missing *)
val thy1 = thy |>
Sign.root_path |>
val (ty_eqs, rlz_eqs) = split_list
(map (fn (s, rs) => mk_realizes_eqn (not (s mem rsets)) vs nparms rs) rss);

(* copy of the theory with the types tnames declared at arity ar;
   NOTE(review): pipeline ends in a dangling "|>" *)
val thy1' = thy1 |>
Theory.copy |>
Sign.add_types (map (fn s => (Sign.base_name s, ar, NoSyn)) tnames) |>
fold (fn s => AxClass.axiomatize_arity
(s, replicate ar HOLogic.typeS, HOLogic.typeS)) tnames |>
val dts = List.mapPartial (fn (s, rs) => if s mem rsets then
SOME (dt_of_intrs thy1' vs nparms rs) else NONE) rss;

(** datatype representing computational content of inductive set **)

(* NOTE(review): the function applied to these arguments (and whatever
   follows the datatype definition) is missing *)
val ((dummies, dt_info), thy2) =
thy1
(DatatypePackage.add_datatype_i false false (map #2 dts))
(map (pair false) dts) []
(* get f opt: f applied to the contents of opt, or [] for NONE *)
fun get f = (these oo Option.map) f;
val rec_names = distinct (op =) (map (fst o dest_Const o head_of o fst o
HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of) (get #rec_thms dt_info));
(* NOTE(review): the initial state and list argument of foldl_map are missing *)
val (_, constrss) = foldl_map (fn ((recs, dummies), (s, rs)) =>
if s mem rsets then
let
val (d :: dummies') = dummies;
val (recs1, recs2) = chop (length rs) (if d then tl recs else recs)
in ((recs2, dummies'), map (head_of o hd o rev o snd o strip_comb o
fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of) recs1)
end
else ((recs, dummies), replicate (length rs) Extraction.nullt))
(* NOTE(review): the list argument of this map is missing *)
val rintrs = map (fn (intr, c) => Envir.eta_contract
(Extraction.realizes_of thy2 vs
(if c = Extraction.nullt then c else list_comb (c, map Var (rev
(Term.add_vars (prop_of intr) []) \\ params'))) (prop_of intr)))
(* one entry per distinct head constant of the realizing introduction rules *)
val (rlzpreds, rlzpreds') = split_list
(distinct (op = o pairself (#1 o #1)) (map (fn rintr =>
let
val Const (s, T) = head_of (HOLogic.dest_Trueprop
(Logic.strip_assums_concl rintr));
val s' = Sign.base_name s;
val T' = Logic.unvarifyT T
in (((s', T'), NoSyn),
(Const (s, T'), Free (s', T')))
end) rintrs));
val rlzparams = map (fn Var ((s, _), T) => (s, Logic.unvarifyT T))
(List.take (snd (strip_comb
(HOLogic.dest_Trueprop (Logic.strip_assums_concl (hd rintrs)))), nparms));

(** realizability predicate **)

(* NOTE(review): the function that receives the record and the following
   arguments is missing after "thy2 |>" *)
val (ind_info, thy3') = thy2 |>
{quiet_mode = false, verbose = false, kind = Thm.theoremK, alt_name = "",
coind = false, no_elim = false, no_ind = false}
rlzpreds rlzparams (map (fn (rintr, intr) =>
((Sign.base_name (name_of_thm intr), []),
subst_atomic rlzpreds' (Logic.unvarify rintr)))
(rintrs ~~ maps snd rss)) [] ||>
Sign.absolute_path;
val thy3 = PureThy.hide_thms false
(map name_of_thm (#intrs ind_info)) thy3';

(** realizer for induction rule **)

val Ps = List.mapPartial (fn _ \$ M \$ P => if pred_of M mem rsets then
SOME (fst (fst (dest_Var (head_of P)))) else NONE)
(HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of raw_induct)));

(* NOTE(review): the header of a local function definition is missing
   before this "let"; the body refers to thy and Ps *)
let
val vs' = rename (map (pairself (fst o fst o dest_Var))
(params ~~ List.take (snd (strip_comb (HOLogic.dest_Trueprop
(hd (prems_of (hd inducts))))), nparms))) vs;
val rs = indrule_realizer thy induct raw_induct rsets params'
(vs' @ Ps) rec_names rss' intrs dummies;
val rlzs = map (fn (r, ind) => Extraction.realizes_of thy (vs' @ Ps) r
(prop_of ind)) (rs ~~ inducts);
val used = foldr add_term_free_names [] rlzs;
val rnames = Name.variant_list used (replicate (length inducts) "r");
val rnames' = Name.variant_list
(used @ rnames) (replicate (length intrs) "s");
val rlzs' as (prems, _, _) :: _ = map (fn (rlz, name) =>
let
val (P, Q) = strip_one name (Logic.unvarify rlz);
val Q' = strip_all' [] rnames' Q
in
(Logic.strip_imp_prems Q', P, Logic.strip_imp_concl Q')
end) (rlzs ~~ rnames);
val concl = HOLogic.mk_Trueprop (foldr1 HOLogic.mk_conj (map
(fn (_, _ \$ P, _ \$ Q) => HOLogic.mk_imp (P, Q)) rlzs'));
val rews = map mk_meta_eq
(fst_conv :: snd_conv :: get #rec_thms dt_info);
(* correctness proof: induction over the realizability predicate *)
val thm = Goal.prove_global thy [] prems concl (fn prems => EVERY
[rtac (#raw_induct ind_info) 1,
rewrite_goals_tac rews,
REPEAT ((resolve_tac prems THEN_ALL_NEW EVERY'
[K (rewrite_goals_tac rews), ObjectLogic.atomize_prems_tac,
DEPTH_SOLVE_1 o FIRST' [atac, etac allE, etac impE]]) 1)]);
val (thm', thy') = PureThy.store_thm (space_implode "_"
(NameSpace.qualified qualifier "induct" :: vs' @ Ps @ ["correctness"]), thm) thy;
val thms = map (fn th => zero_var_indexes (rotate_prems ~1 (th RS mp)))
(DatatypeAux.split_conj_thm thm');
(* NOTE(review): the binding that defines thms' and thy'' (both used below)
   is missing in front of this argument list *)
[((space_implode "_"
(NameSpace.qualified qualifier "inducts" :: vs' @ Ps @
["correctness"]), thms), [])] thy';
val realizers = inducts ~~ thms' ~~ rlzs ~~ rs;
in
(map (fn (((ind, corr), rlz), r) =>
mk_realizer thy' (vs' @ Ps) (Thm.get_name ind, ind, corr, rlz, r))
realizers @ (case realizers of
[(((ind, corr), rlz), r)] =>
[mk_realizer thy' (vs' @ Ps) (NameSpace.qualified qualifier "induct",
ind, corr, rlz, r)]
| _ => [])) thy''
end;

(** realizer for elimination rules **)

val case_names = map (fst o dest_Const o head_of o fst o HOLogic.dest_eq o
HOLogic.dest_Trueprop o prop_of o hd) (get #case_thms dt_info);

(* NOTE(review): the "fun" keyword, name and first parameter(s) of this local
   function are missing; the body refers to Ps *)
(((((elim, elimR), intrs), case_thms), case_name), dummy) thy =
let
val (prem :: prems) = prems_of elim;
fun reorder1 (p, (_, intr)) =
Library.foldl (fn (t, ((s, _), T)) => all T \$ lambda (Free (s, T)) t)
(strip_all p, Term.add_vars (prop_of intr) [] \\ params');
fun reorder2 ((ivs, intr), i) =
let val fs = Term.add_vars (prop_of intr) [] \\ params'
in Library.foldl (fn (t, x) => lambda (Var x) t)
(list_comb (Bound (i + length ivs), ivs), fs)
end;
val p = Logic.list_implies
(map reorder1 (prems ~~ intrs) @ [prem], concl_of elim);
val T' = Extraction.etype_of thy (vs @ Ps) [] p;
val T = if dummy then (HOLogic.unitT --> body_type T') --> T' else T';
val Ts = map (Extraction.etype_of thy (vs @ Ps) []) (prems_of elim);
(* realizer term: Extraction.nullt if Ps is empty, otherwise an application
   of the case constant named case_name *)
val r = if null Ps then Extraction.nullt
else list_abs (map (pair "x") Ts, list_comb (Const (case_name, T),
(if dummy then
[Abs ("x", HOLogic.unitT, Const ("arbitrary", body_type T))]
else []) @
map reorder2 (intrs ~~ (length prems - 1 downto 0)) @
[Bound (length prems)]));
val rlz = Extraction.realizes_of thy (vs @ Ps) r (prop_of elim);
val rlz' = strip_all (Logic.unvarify rlz);
val rews = map mk_meta_eq case_thms;
val thm = Goal.prove_global thy []
(Logic.strip_imp_prems rlz') (Logic.strip_imp_concl rlz') (fn prems => EVERY
[cut_facts_tac [hd prems] 1,
etac elimR 1,
ALLGOALS (asm_simp_tac HOL_basic_ss),
rewrite_goals_tac rews,
REPEAT ((resolve_tac prems THEN_ALL_NEW (ObjectLogic.atomize_prems_tac THEN'
DEPTH_SOLVE_1 o FIRST' [atac, etac allE, etac impE])) 1)]);
val (thm', thy') = PureThy.store_thm (space_implode "_"
(name_of_thm elim :: vs @ Ps @ ["correctness"]), thm) thy
in
[mk_realizer thy' (vs @ Ps) (name_of_thm elim, elim, thm', rlz, r)] thy'
end;

(** add realizers to theory **)

(* the local add_ind_realizer (header missing above) is folded over all
   subsets of Ps *)
val thy4 = Library.foldl add_ind_realizer (thy3, subsets Ps);
(* NOTE(review): the binding of thy5 (used below) and the end of this
   expression are missing *)
(map (mk_realizer thy4 vs) (map (fn (((rule, rrule), rlz), c) =>
(name_of_thm rule, rule, rrule, rlz,
list_comb (c, map Var (rev (Term.add_vars (prop_of rule) []) \\ params'))))
(List.concat (map snd rss) ~~ #intrs ind_info ~~ rintrs ~~
val elimps = List.mapPartial (fn ((s, intrs), p) =>
if s mem rsets then SOME (p, intrs) else NONE)
(rss' ~~ (elims ~~ #elims ind_info));
(* NOTE(review): part of the function body after "thy |>" is missing *)
val thy6 = Library.foldl (fn (thy, p as (((((elim, _), _), _), _), _)) => thy |>
(HOLogic.dest_Trueprop (concl_of elim))))] p) (thy5,
elimps ~~ get #case_thms dt_info ~~ case_names ~~ dummies)

in Sign.restore_naming thy thy6 end;

(* Looks up the inductive definition `name` and applies add_ind_realizer
   once for every subset of the relevant variables of the first
   introduction rule's conclusion, smaller subsets first. *)
fun add_ind_realizers name rsets thy =
  let
    val (_, {intrs, induct, raw_induct, elims, ...}) =
      InductivePackage.the_inductive (ProofContext.init thy) name;
    val rvars = map fst (relevant_vars (concl_of (hd intrs)));
    val vss = sort (int_ord o pairself length) (subsets rvars);
    fun step (thy', vs) =
      add_ind_realizer rsets intrs induct raw_induct elims (thy', vs)
  in
    Library.foldl step (thy, vss)
  end

(* rlz_attrib arg: declaration attribute that adds realizers for the
   inductive predicates of an induction rule.
   The predicate names are read off the rule: from its first premise if the
   conclusion is a single conjunct, otherwise from the left-hand sides of the
   implications in the conclusion.  A rule of any other shape is rejected
   with "ind_realizer: bad rule".
   arg selects the predicates passed to add_ind_realizers as rsets:
     NONE              - all predicates of the rule;
     SOME NONE         - none;
     SOME (SOME sets') - all predicates except those in sets'.
   The first predicate name identifies the inductive definition.  Previously
   the computed list was handed to Context.mapping directly and
   add_ind_realizers was never called. *)
fun rlz_attrib arg = Thm.declaration_attribute (fn thm => Context.mapping
  let
    fun err () = error "ind_realizer: bad rule";
    val sets =
      (case HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of thm)) of
           [_] => [pred_of (HOLogic.dest_Trueprop (hd (prems_of thm)))]
         | xs => map (pred_of o fst o HOLogic.dest_imp) xs)
         handle TERM _ => err () | Empty => err ();
  in
    add_ind_realizers (hd sets)
      (case arg of
        NONE => sets | SOME NONE => []
      | SOME (SOME sets') => sets \\ sets')
  end I);

(* Attribute syntax: an optional keyword "irrelevant", itself optionally
   followed by a colon and a non-empty list of constants. *)
val ind_realizer =
  let
    val irrelevant_consts =
      Scan.lift (Args.colon) |-- Scan.repeat1 Args.const;
    val irrelevant =
      Scan.lift (Args.\$\$\$ "irrelevant") |-- Scan.option irrelevant_consts;
  in
    Attrib.syntax (Scan.option irrelevant >> rlz_attrib)
  end;

val setup =