1 |
(* Interface of a lazy term evaluator: terms are rewritten with function
   equations until their head is a constructor (whnf), and patterns are
   matched against terms, evaluating sub-terms only where a constructor
   pattern demands it (match, match_all). *)
signature LAZY_EVAL = sig

(* A pattern is either a variable (identified by its indexname) or a
   constructor name applied to argument patterns. *)
datatype pat = AnyPat of indexname | ConsPat of (string * pat list)

(* Constructor name together with its arity. *)
type constructor = string * int

(* A function equation, pre-analysed for matching. *)
type equation = {
    function : term,   (* head of the left-hand side *)
    thm : thm,         (* the equation as a rewrite rule *)
    rhs : term,        (* right-hand side *)
    pats : pat list    (* one pattern per argument of the left-hand side *)
  }

(* The hook-independent part of the evaluation context. *)
type eval_ctxt' = {
    equations : equation list,
    constructors : constructor list,
    pctxt : Proof.context,
    facts : thm Net.net,
    verbose : bool
  }

(* A hook may take over the evaluation of a term by returning a replacement
   term together with a conversion; NONE means "not applicable". *)
type eval_hook = eval_ctxt' -> term -> (term * conv) option

(* Full evaluation context: the core context plus the installed hooks. *)
type eval_ctxt = {
    ctxt : eval_ctxt',
    hooks : eval_hook list
  }

val is_constructor_name : constructor list -> string -> bool
val constructor_arity : constructor list -> string -> int option

val mk_eval_ctxt : Proof.context -> constructor list -> thm list -> eval_ctxt
val add_facts : thm list -> eval_ctxt -> eval_ctxt
val get_facts : eval_ctxt -> thm list
val get_ctxt : eval_ctxt -> Proof.context
val add_hook : eval_hook -> eval_ctxt -> eval_ctxt
val get_verbose : eval_ctxt -> bool
val set_verbose : bool -> eval_ctxt -> eval_ctxt
val get_constructors : eval_ctxt -> constructor list
val set_constructors : constructor list -> eval_ctxt -> eval_ctxt

(* Evaluate a term; the result is the evaluated term and a conversion. *)
val whnf : eval_ctxt -> term -> term * conv

(* Match one term against one pattern, threading an optional binding table
   (NONE = a previous match already failed).  Returns the updated table, the
   (possibly further evaluated) term, and a conversion. *)
val match : eval_ctxt -> pat -> term ->
  (indexname * term) list option -> (indexname * term) list option * term * conv

(* List version of match: patterns and terms are matched pairwise. *)
val match_all : eval_ctxt -> pat list -> term list ->
  (indexname * term) list option -> (indexname * term) list option * term list * conv

end
49 |
50 |
structure Lazy_Eval : LAZY_EVAL = struct
51 |
52 |
(* Patterns used for matching: a variable, or a constructor name applied to
   argument patterns. *)
datatype pat =
    AnyPat of indexname
  | ConsPat of string * pat list
53 |
54 |
(* Constructor name paired with its arity (see constructor_arity). *)
type constructor = string * int
55 |
56 |
(* A function equation as produced by analyze_eq. *)
type equation = {
    function : term,   (* head of the left-hand side *)
    thm : thm,         (* the equation, turned into a meta-equality *)
    rhs : term,        (* right-hand side *)
    pats : pat list    (* patterns for the arguments of the left-hand side *)
  }
62 |
63 |
(* Core evaluation context; this is the part that is handed to hooks. *)
type eval_ctxt' = {
    equations : equation list,         (* rewrite equations, see find_eqs *)
    constructors : constructor list,   (* known constructors with arities *)
    pctxt : Proof.context,             (* underlying proof context *)
    facts : thm Net.net,               (* auxiliary facts, see add_facts *)
    verbose : bool                     (* verbosity flag; not consulted in this structure *)
  }
70 |
71 |
(* A hook receives the core context and a term; SOME (t', conv) makes the
   evaluator continue with t', NONE leaves the term to the next hook or to
   the equations (see whnf_aux1). *)
type eval_hook = eval_ctxt' -> term -> (term * conv) option
72 |
73 |
(* Full evaluation context: the core context plus the installed hooks. *)
type eval_ctxt = {
    ctxt : eval_ctxt',
    hooks : eval_hook list
  }
77 |
78 |
(* Install hook h in front of the existing hooks, so it is tried first. *)
fun add_hook h ({hooks, ctxt} : eval_ctxt) : eval_ctxt =
  {hooks = h :: hooks, ctxt = ctxt}
80 |
81 |
(* Read the verbosity flag stored in the core context. *)
fun get_verbose (ectxt : eval_ctxt) = #verbose (#ctxt ectxt)
82 |
83 |
(* Return a copy of the context whose verbosity flag is b; all other
   components are carried over unchanged. *)
fun set_verbose b ({ctxt = {equations, pctxt, facts, constructors, ...}, hooks} : eval_ctxt) =
  {ctxt = {equations = equations, pctxt = pctxt, facts = facts,
     constructors = constructors, verbose = b}, hooks = hooks} : eval_ctxt
86 |
87 |
(* Read the constructor table stored in the core context. *)
fun get_constructors (ectxt : eval_ctxt) = #constructors (#ctxt ectxt)
88 |
89 |
(* Return a copy of the context whose constructor table is cs; all other
   components are carried over unchanged.  The stored equations are not
   re-analysed against the new table. *)
fun set_constructors cs ({ctxt = {equations, pctxt, facts, verbose, ...}, hooks} : eval_ctxt) =
  {ctxt = {equations = equations, pctxt = pctxt, facts = facts,
     verbose = verbose, constructors = cs}, hooks = hooks} : eval_ctxt
92 |
93 |
(* NOTE(review): this repeats the alias already declared near the top of the
   structure with the same definition; it looks redundant. *)
type constructor = string * int
94 |
95 |
(* True iff some entry of the constructor table carries the given name. *)
fun is_constructor_name (constructors : constructor list) name =
  member (fn (c, (c', _)) => c = c') constructors name
96 |
97 |
(* Look up the arity recorded for a constructor name; NONE if the name is
   not in the table. *)
fun constructor_arity (constructors : constructor list) name =
  AList.lookup (op =) constructors name
98 |
99 |
(* Translate a term into a pattern.
   A schematic variable becomes AnyPat; a known constructor applied to
   exactly as many arguments as its recorded arity becomes ConsPat, with the
   arguments translated recursively.
   Raises TERM for anything else: an unknown constant, a wrong number of
   arguments, or a head that is not a constant. *)
fun stream_pat_of_term _ (Var v) = AnyPat (fst v)
  | stream_pat_of_term constructors t =
      case strip_comb t of
        (Const (c, _), args) =>
          (case constructor_arity constructors c of
             NONE => raise TERM ("Not a valid pattern.", [t])
           | SOME n =>
               if length args = n then
                 ConsPat (c, map (stream_pat_of_term constructors) args)
               else
                 raise TERM ("Not a valid pattern.", [t]))
      | _ => raise TERM ("Not a valid pattern.", [t])
111 |
112 |
(* Turn a theorem whose conclusion is a HOL equation "f p1 ... pn = rhs"
   into an equation record: the head f, the argument patterns, the
   right-hand side, and the theorem converted to a meta-equality via
   eq_reflection.
   Raises THM if the conclusion cannot be taken apart that way or if an
   argument is not a valid pattern (any TERM exception is re-raised as THM). *)
fun analyze_eq constructors thm =
  let
    val ((f, pats), rhs) =
      thm |> Thm.concl_of |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |>
        apfst (strip_comb #> apsnd (map (stream_pat_of_term constructors)))
      handle TERM _ => raise THM ("Not a valid function equation.", 0, [thm])
  in
    {function = f, thm = thm RS @{thm eq_reflection}, rhs = rhs, pats = pats} : equation
  end
120 |
121 |
(* Build a fresh evaluation context from a proof context, a constructor
   table and a list of function equations.  Every theorem is analysed with
   analyze_eq (which raises THM on malformed equations).  The new context
   starts with no facts, no hooks, and verbose = false. *)
fun mk_eval_ctxt ctxt (constructors : constructor list) thms : eval_ctxt = {
    ctxt = {
      equations = map (analyze_eq constructors) thms,
      facts = Net.empty,
      verbose = false,
      pctxt = ctxt,
      constructors = constructors
    },
    hooks = []
  }
131 |
132 |
(* Insert the given theorems into the fact net of the context, indexed by
   their propositions.  Theorems are compared by proposition; when the net
   rejects an insertion with Net.INSERT (the entry is already present), the
   theorem is skipped. *)
fun add_facts facts' {ctxt = {equations, pctxt, facts, verbose, constructors}, hooks} =
  let
    val eq = op = o apply2 Thm.prop_of
    val facts' =
      fold (fn thm => fn net => Net.insert_term eq (Thm.prop_of thm, thm) net
        handle Net.INSERT => net) facts' facts
  in
    {ctxt = {equations = equations, pctxt = pctxt, facts = facts',
       verbose = verbose, constructors = constructors},
     hooks = hooks}
  end
143 |
144 |
(* List all theorems currently stored in the fact net of the context. *)
fun get_facts ({ctxt = {facts, ...}, ...} : eval_ctxt) = Net.content facts
145 |
146 |
(* Extract the underlying proof context from the evaluation context. *)
fun get_ctxt ({ctxt = {pctxt, ...}, ...} : eval_ctxt) : Proof.context = pctxt
147 |
148 |
(* Select, in their stored order, the equations whose head is a constant
   with the same name as f.  Types are ignored in the comparison; if f is
   not a constant, the result is empty. *)
fun find_eqs (eval_ctxt : eval_ctxt) f =
  let
    fun eq_const (Const (c, _)) (Const (c', _)) = c = c'
      | eq_const _ _ = false
  in
    filter (fn eq => eq_const f (#function eq)) (#equations (#ctxt eval_ctxt))
  end
156 |
157 |
(* Two-way sum type; whnf_aux2 uses Inl for "equation did not apply" and
   Inr for "equation applied". *)
datatype ('a, 'b) either =
    Inl of 'a
  | Inr of 'b
158 |
159 |
(* Evaluate t: beta-normalise it, perform one evaluation step (whnf_aux1),
   and repeat until a step returns a term that is alpha-convertible to its
   input.  The returned conversion is the composition of the conversions of
   all steps, in order. *)
fun whnf (ctxt : eval_ctxt) t =
      case whnf_aux1 ctxt (Envir.beta_norm t) of
        (t', conv) =>
          if t aconv t' then
            (* nothing changed: t' is the final result *)
            (t', conv)
          else
            case whnf ctxt t' of
              (t'', conv') => (t'', conv then_conv conv')
167 |
168 |
(* One evaluation step with hooks: the hooks are tried in order and the
   first one that returns SOME decides; its result term is evaluated further
   with whnf.  If no hook applies, fall back to the equations (whnf_aux2). *)
and whnf_aux1 (ctxt as {hooks, ctxt = ctxt'}) t =
      case get_first (fn h => h ctxt' t) hooks of
        NONE => whnf_aux2 ctxt t
      | SOME (t', conv) =>
          (case whnf ctxt t' of
             (t'', conv') => (t'', conv then_conv conv'))
173 |
(* One evaluation step with the function equations.
   The term is split into head f and arguments.  If f is a constant that is
   a constructor, or f is not a constant at all, the term is returned
   unchanged with the identity conversion.  Otherwise the equations for f
   are tried in order; the first one whose patterns match rewrites the term. *)
and whnf_aux2 ctxt t =
  let
    val (f, args) = strip_comb t

    (* Try a single equation on the current arguments.
       Inr (rhs', conv): the equation applied; rhs' is its instantiated,
         beta-normalised right-hand side.
       Inl (args', conv): it did not apply; args' are the arguments as far
         as matching has evaluated them, so later equations can reuse that
         work. *)
    fun apply_eq ({thm, pats, ...} : equation) conv args =
      let
        val (table, args', conv') = match_all ctxt pats args (SOME [])
        val conv'' = conv then_conv conv'
      in
        case table of
          NONE => Inl (args', conv'')
        | SOME _ =>
            (let
               val thy = Proof_Context.theory_of (get_ctxt ctxt)
               val t' = list_comb (f, args')
               val lhs = Thm.term_of (Thm.lhs_of thm)
               (* the table from match_all only signals success; the actual
                  instantiation is recomputed by matching the lhs *)
               val env = Pattern.match thy (lhs, t') (Vartab.empty, Vartab.empty)
               val rhs =
                 Thm.term_of (Thm.rhs_of thm) |> Envir.subst_term env |> Envir.beta_norm
             in
               Inr (rhs, conv'' then_conv Conv.rewr_conv thm)
             end
             handle Pattern.MATCH => Inl (args', conv''))
      end

    (* Try the equations in order, threading the evaluated arguments and the
       accumulated conversion; with no equation left, rebuild the term. *)
    fun apply_eqs [] args conv = (list_comb (f, args), conv)
      | apply_eqs (eq :: eqs) args conv =
          (case apply_eq eq conv args of
             Inr res => res
           | Inl (args', conv') => apply_eqs eqs args' conv')

  in
    case f of
      Const (f', _) =>
        if is_constructor_name (get_constructors ctxt) f' then
          (t, Conv.all_conv)
        else
          apply_eqs (find_eqs ctxt f) args Conv.all_conv
    | _ => (t, Conv.all_conv)
  end
215 |
(* Match a list of terms against a list of patterns, left to right.
   Returns the resulting binding table (NONE on failure), the terms as far
   as matching has evaluated them, and a conversion.
   If the two lists differ in length, matching fails immediately and the
   terms are returned untouched with the identity conversion. *)
and match_all ctxt pats args table =
  let
    fun match_all' [] [] acc conv table = (table, rev acc, conv)
      | match_all' (_ :: pats) (arg :: args) acc conv NONE =
          (* already failed: keep the remaining arguments as they are *)
          match_all' pats args (arg :: acc) (Conv.fun_conv conv) NONE
      | match_all' (pat :: pats) (arg :: args) acc conv (SOME table) =
          let
            val (table', arg', conv') = match ctxt pat arg (SOME table)
            val conv = Conv.combination_conv conv conv'
            val acc = arg' :: acc
          in
            match_all' pats args acc conv table'
          end
      | match_all' _ _ _ _ _ = raise Match  (* excluded by the length check below *)
  in
    if length pats = length args then
      match_all' pats args [] Conv.all_conv table
    else
      (NONE, args, Conv.all_conv)
  end
235 |
(* Match a single term against a single pattern.
   - Table NONE (an earlier match failed): nothing is done.
   - AnyPat v: bind v to the term as it is, without evaluating it.
   - ConsPat (c, pats): evaluate the term with whnf; if its head is the
     constant c, match the arguments against pats, otherwise fail with NONE.
   In every case the returned term is the input as far as it has been
   evaluated, and the conversion is the one accumulated while doing so. *)
and match _ _ t NONE = (NONE, t, Conv.all_conv)
  | match _ (AnyPat v) t (SOME table) = (SOME ((v, t) :: table), t, Conv.all_conv)
  | match ctxt (ConsPat (c, pats)) t (SOME table) =
      let
        val (t', conv) = whnf ctxt t
        val (f, args) = strip_comb t'
      in
        case f of
          Const (c', _) =>
            if c = c' then
              (case match_all ctxt pats args (SOME table) of
                 (table, args', conv') => (table, list_comb (f, args'), conv then_conv conv'))
            else
              (NONE, t', conv)
        | _ => (NONE, t', conv)
      end
251 |
252 |
end |