1 (* Title: HOL/Tools/datatype_realizer.ML |
|
2 Author: Stefan Berghofer, TU Muenchen |
|
3 |
|
4 Program extraction from proofs involving datatypes: |
|
5 Realizers for induction and case analysis |
|
6 *) |
|
7 |
|
(* Public interface of the datatype realizer: registration of
   program-extraction realizers for the induction and case-analysis
   rules of datatypes. *)
signature DATATYPE_REALIZER =
sig
  (* Theory setup hook. *)
  val setup: theory -> theory
  (* Add realizers for the datatypes with the given names. *)
  val add_dt_realizers: string list -> theory -> theory
end;
|
13 |
|
14 structure DatatypeRealizer : DATATYPE_REALIZER = |
|
15 struct |
|
16 |
|
17 open DatatypeAux; |
|
18 |
|
(* All sublists of the integer range [i, j], each in increasing order.
   The sublists that contain i come first, in the order produced for
   [i+1, j]. The sublists without i follow, so the empty list is always
   last. For i > j the result is [[]]. *)
fun subsets i j =
  if i > j then [[]]
  else
    let
      val without_i = subsets (i + 1) j;
      val with_i = map (fn ks => i :: ks) without_i;
    in with_i @ without_i end;
|
23 |
|
(* Abstract the proof term prf over the variable t, producing an Abst node
   that binds t's name at t's type. t must be a schematic Var (its index
   is dropped) or a Free; any other term raises Match. *)
fun forall_intr_prf (t, prf) =
  let
    fun name_and_type (Var ((a, _), T)) = (a, T)
      | name_and_type (Free (a, T)) = (a, T);
    val (a, T) = name_and_type t;
  in Abst (a, SOME T, Proofterm.prf_abstract_over t prf) end;
|
27 |
|
(* Full proof term of thm, reconstructed against the proposition of thm
   in the theory of thm. *)
fun prf_of thm =
  let
    val thy = Thm.theory_of_thm thm;
    val prop = Thm.prop_of thm;
  in Reconstruct.reconstruct_proof thy prop (Thm.proof_of thm) end;
|
30 |
|
(* Apply the term-variable instantiation inst to every term in a proof
   term. subst_vars is called with an empty type instantiation, and types
   in the proof are mapped by the identity. *)
fun prf_subst_vars inst =
  let val subst_terms = subst_vars ([], inst)
  in Proofterm.map_proof_terms subst_terms (fn T => T) end;
|
33 |
|
(* True iff the final result type of t, after stripping all function
   argument types, is HOLogic.unitT. *)
fun is_unit t = body_type (fastype_of t) = HOLogic.unitT;
|
35 |
|
(* Name of the type constructor of a Type application; the empty string
   for any other type. *)
fun tname_of T =
  (case T of
    Type (name, _) => name
  | _ => "");
|
38 |
|
(* The constant realizes at type T => bool => bool. *)
fun mk_realizes T = Const ("realizes", [T, HOLogic.boolT] ---> HOLogic.boolT);
|
40 |
|
(* make_ind sorts info is thy:
   Build and store a correctness theorem for the induction rule of the
   datatype description in info. Then register a realizer for that rule
   with Extraction.add_realizers_i.
   is -- indices into descr whose induction predicates are computationally
         relevant. Such a predicate takes an extra realizer argument whose
         type is a fresh type variable. All other predicates get realizer
         type unit. When is is empty, the realizer is Extraction.nullt.
   The theorem is stored at the root path. Its name is ind_name, the
   selected predicate names and the word correctness, joined by
   underscores. *)
41 fun make_ind sorts ({descr, rec_names, rec_rewrites, induction, ...} : datatype_info) is thy = |

42 let |

43 val recTs = get_rec_types descr sorts; |

(* Predicate names: P for a single datatype, P1 ... Pn otherwise. *)
44 val pnames = if length descr = 1 then ["P"] |

45 else map (fn i => "P" ^ string_of_int i) (1 upto length descr); |

46 |

(* Realizer type per datatype. If the index is in is, it is a TFree named
   after the predicate with a leading apostrophe (sort HOLogic.typeS);
   otherwise it is unit. *)
47 val rec_result_Ts = map (fn ((i, _), P) => |

48 if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT) |

49 (descr ~~ pnames); |

50 |

(* Predicate application for datatype i. If i is realized, the predicate
   takes realizer r (type T) and value x (type U); otherwise it takes
   only x. *)
51 fun make_pred i T U r x = |

52 if i mem is then |

53 Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x |

54 else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x; |

55 |

(* Universally quantify over the realizer variable s only when datatype i
   is realized. *)
56 fun mk_all i s T t = |

57 if i mem is then list_all_free ([(s, T)], t) else t; |

58 |

(* For every constructor of every datatype, build one premise of the
   correctness statement and the function f' from which the
   recursion-combinator arguments (rec_fns) are formed. The counter j
   numbers constructors consecutively from 1 across all datatypes. It
   keeps the f variables distinct. *)
59 val (prems, rec_fns) = split_list (flat (fst (fold_map |

60 (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j => |

61 let |

62 val Ts = map (typ_of_dtyp descr sorts) cargs; |

63 val tnames = Name.variant_list pnames (DatatypeProp.make_tnames Ts); |

64 val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts); |

65 val frees = tnames ~~ Ts; |

66 |

(* mk_prems vs ds: vs accumulates the constructor arguments followed by
   one realizer variable per recursive argument already processed from
   ds. Base case: the predicate for the constructor application, with
   realizer f applied to the non-unit elements of vs. The result also
   contains f', the eta-contracted abstraction over vs of that
   application. If datatype i is not realized, f' is HOLogic.unit. *)
67 fun mk_prems vs [] = |

68 let |

69 val rT = nth (rec_result_Ts) i; |

70 val vs' = filter_out is_unit vs; |

71 val f = mk_Free "f" (map fastype_of vs' ---> rT) j; |

72 val f' = Envir.eta_contract (list_abs_free |

73 (map dest_Free vs, if i mem is then list_comb (f, vs') |

74 else HOLogic.unit)); |

75 in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs')) |

76 (list_comb (Const (cname, Ts ---> T), map Free frees))), f') |

77 end |

(* Recursive argument s of type Us => U, belonging to datatype k. Add the
   hypothesis that the predicate for k holds for r_s applied to the bound
   arguments, and that it holds for s applied to the same arguments,
   quantified over them. Quantify over r_s when k is realized. Note that
   i is rebound here to the number of arguments, length Us. *)
78 | mk_prems vs (((dt, s), T) :: ds) = |

79 let |

80 val k = body_index dt; |

81 val (Us, U) = strip_type T; |

82 val i = length Us; |

83 val rT = nth (rec_result_Ts) k; |

84 val r = Free ("r" ^ s, Us ---> rT); |

85 val (p, f) = mk_prems (vs @ [r]) ds |

86 in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies |

87 (list_all (map (pair "x") Us, HOLogic.mk_Trueprop |

88 (make_pred k rT U (app_bnds r i) |

89 (app_bnds (Free (s, T)) i))), p)), f) |

90 end |

91 |

92 in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end) |

93 constrs) (descr ~~ recTs) 1))); |

94 |

(* mk_proj j is t: select the component for datatype j from t, the nested
   pair built over the indices is. The last index in is receives the
   remaining t itself. *)
95 fun mk_proj j [] t = t |

96 | mk_proj j (i :: is) t = if null is then t else |

97 if (j: int) = i then HOLogic.mk_fst t |

98 else mk_proj j is (HOLogic.mk_snd t); |

99 |

100 val tnames = DatatypeProp.make_tnames recTs; |

101 val fTs = map fastype_of rec_fns; |

(* Instantiation for the predicate variables of induction. Each entry is
   lambda x. the predicate for datatype i, applied to the recursion
   constant s (taken from rec_names) with rec_fns on x, and to x. *)
102 val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T |

103 (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0))) |

104 (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names); |

(* The realizer is a tuple, built with foldr1 HOLogic.mk_prod. It holds
   the recursion-combinator applications for the realized datatypes,
   applied to Frees named by tnames. If is is empty, the realizer is
   Extraction.nullt. *)
105 val r = if null is then Extraction.nullt else |

106 foldr1 HOLogic.mk_prod (List.mapPartial (fn (((((i, _), T), U), s), tname) => |

107 if i mem is then SOME |

108 (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T)) |

109 else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames)); |

(* Conclusion: conjunction over all datatypes of the predicate, applied
   to the matching projection of r and the value. *)
110 val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &") |

111 (map (fn ((((i, _), T), U), tname) => |

112 make_pred i U T (mk_proj i is r) (Free (tname, T))) |

113 (descr ~~ recTs ~~ rec_result_Ts ~~ tnames))); |

114 val cert = cterm_of thy; |

(* Pair the predicate heads in the conjunctive conclusion of induction
   with ps. *)
115 val inst = map (pairself cert) (map head_of (HOLogic.dest_conj |

116 (HOLogic.dest_Trueprop (concl_of induction))) ~~ ps); |

117 |

(* Prove prems ==> concl in the following steps:
   1. rewrite the goal with fst_conv and snd_conv;
   2. apply the instantiated induction rule and atomize the premises;
   3. unfold o_def and the recursion equations;
   4. close each goal with a premise, eliminating universal quantifiers
      with allE and finishing with assumptions. *)
118 val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl))) |

119 (fn prems => |

120 [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]), |

121 rtac (cterm_instantiate inst induction) 1, |

122 ALLGOALS ObjectLogic.atomize_prems_tac, |

123 rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites), |

124 REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i => |

125 REPEAT (etac allE i) THEN atac i)) 1)]); |

126 |

(* Store the theorem at the root path, then restore the naming of thy. *)
127 val ind_name = Thm.get_name induction; |

128 val vs = map (fn i => List.nth (pnames, i)) is; |

129 val (thm', thy') = thy |

130 |> Sign.root_path |

131 |> PureThy.store_thm |

132 (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm) |

133 ||> Sign.restore_naming thy; |

134 |

(* ivs: schematic variables of the varified generic induction statement
   from DatatypeProp.make_ind (the reversed result of Term.add_vars).
   ivs1: those whose body type is not named set or bool. NOTE(review):
   presumably these are the non-predicate variables; confirm.
   ivs2: all of ivs, retyped with the type of the same-named variable in
   thm'. valOf raises Option if one is missing. *)
135 val ivs = rev (Term.add_vars (Logic.varify (DatatypeProp.make_ind [descr] sorts)) []); |

136 val rvs = rev (Thm.fold_terms Term.add_vars thm' []); |

137 val ivs1 = map Var (filter_out (fn (_, T) => |

138 tname_of (body_type T) mem ["set", "bool"]) ivs); |

139 val ivs2 = map (fn (ixn, _) => Var (ixn, valOf (AList.lookup (op =) rvs ixn))) ivs; |

140 |

(* Proof term for the realizer:
   - apply the proof of thm' to one PBound per premise;
   - bind one hypothesis H per premise of thm;
   - for each rec_fn whose head is a Free, also abstract over that Free
     as the schematic variable with index 0 and varified type;
   - finally abstract over ivs2. *)
141 val prf = List.foldr forall_intr_prf |

142 (List.foldr (fn ((f, p), prf) => |

143 (case head_of (strip_abs_body f) of |

144 Free (s, T) => |

145 let val T' = Logic.varifyT T |

146 in Abst (s, SOME T', Proofterm.prf_abstract_over |

147 (Var ((s, 0), T')) (AbsP ("H", SOME p, prf))) |

148 end |

149 | _ => AbsP ("H", SOME p, prf))) |

150 (Proofterm.proof_combP |

151 (prf_of thm', map PBound (length prems - 1 downto 0))) (rec_fns ~~ prems_of thm)) ivs2; |

152 |

(* Generalize the realizer over the unvarified ivs1 and the non-unit
   heads of rec_fns, then varify. When is is empty, r is left as is. *)
153 val r' = if null is then r else Logic.varify (List.foldr (uncurry lambda) |

154 r (map Logic.unvarify ivs1 @ filter_out is_unit |

155 (map (head_of o strip_abs_body) rec_fns))); |

156 |

(* Register the realizer for the induction rule. vs names the realized
   predicates. *)
157 in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end; |
|
158 |
|
159 |
|
(* make_casedists sorts info thy:
   Register realizers for the case-analysis (exhaustion) rule of the
   datatype at position index in descr. Two realizers are added under
   exh_name:
   - one in which the predicate P is computationally relevant, realized
     by the case combinator case_name;
   - one trivial realizer, Extraction.nullt, backed by the proof of
     exhaustion itself.
   The supporting correctness theorem is stored at the root path under
   exh_name followed by _P_correctness. *)
160 fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaustion, ...} : datatype_info) thy = |

161 let |

162 val cert = cterm_of thy; |

(* Realizer result type: a type variable named P with a leading
   apostrophe, sort HOLogic.typeS. rT' is its schematic (TVar)
   counterpart. *)
163 val rT = TFree ("'P", HOLogic.typeS); |

164 val rT' = TVar (("'P", 0), HOLogic.typeS); |

165 |

(* For constructor cname with argument types Ts, return:
   - a realizer variable r, named r followed by the base name of cname,
     of type Ts => rT;
   - the premise ALL frees. y = cname frees ==> P (r frees).
   The names of frees are chosen to avoid P and y. *)
166 fun make_casedist_prem T (cname, cargs) = |

167 let |

168 val Ts = map (typ_of_dtyp descr sorts) cargs; |

169 val frees = Name.variant_list ["P", "y"] (DatatypeProp.make_tnames Ts) ~~ Ts; |

170 val free_ts = map Free frees; |

171 val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT) |

172 in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop |

173 (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))), |

174 HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ |

175 list_comb (r, free_ts))))) |

176 end; |

177 |

(* Constructors of datatype index. This binding raises Bind if index is
   not a key of descr. *)
178 val SOME (_, _, constrs) = AList.lookup (op =) descr index; |

179 val T = List.nth (get_rec_types descr sorts, index); |

180 val (rs, prems) = split_list (map (make_casedist_prem T) constrs); |

181 val r = Const (case_name, map fastype_of rs ---> T --> rT); |

182 |

(* y: the schematic variable of exhaustion that gets instantiated.
   y': the Free that replaces it in the goal. *)
183 val y = Var (("y", 0), Logic.legacy_varifyT T); |

184 val y' = Free ("y", T); |

185 |

(* Prove prems ==> P (case_name rs y') in the following steps:
   1. split cases on y' using the instantiated exhaustion rule;
   2. in every goal, simplify with case_rewrites;
   3. resolve with a premise;
   4. finish by simplification. *)
186 val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems, |

187 HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ |

188 list_comb (r, rs @ [y']))))) |

189 (fn prems => |

190 [rtac (cterm_instantiate [(cert y, cert y')] exhaustion) 1, |

191 ALLGOALS (EVERY' |

192 [asm_simp_tac (HOL_basic_ss addsimps case_rewrites), |

193 resolve_tac prems, asm_simp_tac HOL_basic_ss])]); |

194 |

(* Store the theorem at the root path, then restore the naming of thy. *)
195 val exh_name = Thm.get_name exhaustion; |

196 val (thm', thy') = thy |

197 |> Sign.root_path |

198 |> PureThy.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm) |

199 ||> Sign.restore_naming thy; |

200 |

(* Proof term: apply the proof of thm' to one PBound per premise. Then
   abstract, for each constructor, over a hypothesis H for its varified
   premise and over its legacy-varified realizer variable. Finally
   abstract over P and y. *)
201 val P = Var (("P", 0), rT' --> HOLogic.boolT); |

202 val prf = forall_intr_prf (y, forall_intr_prf (P, |

203 List.foldr (fn ((p, r), prf) => |

204 forall_intr_prf (Logic.legacy_varify r, AbsP ("H", SOME (Logic.varify p), |

205 prf))) (Proofterm.proof_combP (prf_of thm', |

206 map PBound (length prems - 1 downto 0))) (prems ~~ rs))); |

(* Realizer: lambda y r1 ... rn. case_name r1 ... rn y. The rs are bound
   innermost; y is the outermost binder, index length rs. *)
207 val r' = Logic.legacy_varify (Abs ("y", Logic.legacy_varifyT T, |

208 list_abs (map dest_Free rs, list_comb (r, |

209 map Bound ((length rs - 1 downto 0) @ [length rs]))))); |

210 |

(* Register both realizers: the computational one for P and the trivial
   one (Extraction.nullt, proved by exhaustion). *)
211 in Extraction.add_realizers_i |

212 [(exh_name, (["P"], r', prf)), |

213 (exh_name, ([], Extraction.nullt, prf_of exhaustion))] thy' |

214 end; |
|
215 |
|
(* Add program-extraction realizers for the datatypes named in names.

   The descr and sorts of the FIRST datatype are used to build the
   induction realizers. NOTE(review): this presumably assumes that names
   forms one (mutually recursive) datatype group; confirm with callers.

   For every subset is of the indices 0 .. length descr - 1, make_ind
   registers a realizer for the shared induction rule in which exactly the
   predicates with indices in is are computationally relevant. After that,
   make_casedists registers case-analysis realizers for every datatype in
   names.

   thy is returned unchanged in two cases:
   - proof terms are recorded at a level below 2 (! Proofterm.proofs < 2).
     The realizers are built from reconstructed proof terms, which
     presumably requires full proofs.
   - names is empty. Previously the non-exhaustive binding of the first
     info raised Bind here. *)
fun add_dt_realizers names thy =
  if ! Proofterm.proofs < 2 orelse null names then thy
  else
    let
      val _ = message "Adding realizers for induction and case analysis ...";
      val infos = map (DatatypePackage.the_datatype thy) names;
      (* names is non-empty here, so infos has a head. *)
      val info = hd infos;
      val sorts = #sorts info;
    in
      thy
      |> fold_rev (make_ind sorts info) (subsets 0 (length (#descr info) - 1))
      |> fold_rev (make_casedists sorts) infos
    end;
|
227 |
|
(* Theory setup: hand add_dt_realizers to DatatypePackage.interpretation.
   NOTE(review): presumably it then runs for each datatype group as it is
   defined; confirm against DatatypePackage. *)
val setup = add_dt_realizers |> DatatypePackage.interpretation;
|
229 |
|
230 end; |
|