Command "fun" for fully automated function definitions
*)
7 |
7 |
(* Interface of the "fun" command. *)
signature FUNCTION_FUN =
sig
  (* Defines functions from typed fixes and term equations.  The bool
     argument is passed on to the terminal proofs run by the implementation. *)
  val add_fun : Function_Common.function_config ->
    (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
    bool -> local_theory -> Proof.context
  (* Variant of add_fun taking types and equations as strings. *)
  val add_fun_cmd : Function_Common.function_config ->
    (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    bool -> local_theory -> Proof.context

  (* Theory setup of this module. *)
  val setup : theory -> theory
end
19 |
19 |
20 structure Function_Fun : FUNCTION_FUN = |
20 structure Function_Fun : FUNCTION_FUN = |
21 struct |
21 struct |
22 |
22 |
23 open Function_Lib |
23 open Function_Lib |
24 open Function_Common |
24 open Function_Common |
25 |
25 |
26 |
26 |
(* Checks that the equation geq is acceptable in sequential mode.  An error
   (showing the equation) is raised if split_def yields a non-empty gs
   component ("Conditional equations"), if an argument is not built from
   bound variables and datatype constructors, or if the arguments contain
   more bound-variable occurrences than there are quantified variables qs
   ("Nonlinear patterns").  Returns unit otherwise. *)
fun check_pats ctxt geq =
  let
    fun err str = error (cat_lines ["Malformed definition:",
      str ^ " not allowed in sequential mode.",
      Syntax.string_of_term ctxt geq])
    val thy = ProofContext.theory_of ctxt

    (* a pattern is a bound variable, or a head known to
       Datatype.info_of_constr applied to patterns *)
    fun check_constr_pattern (Bound _) = ()
      | check_constr_pattern t =
      let
        val (hd, args) = strip_comb t
      in
        (((case Datatype.info_of_constr thy (dest_Const hd) of
             SOME _ => ()
           | NONE => err "Non-constructor pattern")
          (* the head is not a constant at all *)
          handle TERM ("dest_Const", _) => err "Non-constructor patterns");
         List.app check_constr_pattern args)
      end

    val (_, qs, gs, args, _) = split_def ctxt geq

    val _ = if not (null gs) then err "Conditional equations" else ()
    val _ = List.app check_constr_pattern args

    (* just count occurrences to check linearity *)
    val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
            then err "Nonlinear patterns" else ()
  in
    ()
  end
58 |
58 |
(* Terminal proof using the pat_completeness method followed by the
   "HOL.auto" method source.  The result still expects the bool argument
   of Proof.global_future_terminal_proof and then the proof state. *)
val by_pat_completeness_auto =
  Proof.global_future_terminal_proof
    (Method.Basic Pat_Completeness.pat_completeness,
     SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
63 |
63 |
(* Starts a termination proof (Function.termination_proof NONE) and closes
   it immediately with the given method as terminal proof; int is passed on
   to Proof.global_future_terminal_proof. *)
fun termination_by method int =
  Function.termination_proof NONE
  #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
67 |
67 |
(* Builds one catch-all equation per entry of fixes:
     !!a1 ... an. f a1 ... an = undefined
   where n = arity_of f.  The variables get the first n argument types of f;
   the right-hand side gets the remaining function type. *)
fun mk_catchall fixes arity_of =
  let
    fun mk_eqn ((fname, fT), _) =
      let
        val n = arity_of fname
        (* split the argument types after n; the rest forms the result type *)
        val (argTs, rT) = chop n (binder_types fT)
          |> apsnd (fn Ts => Ts ---> body_type fT)

        val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
      in
        HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
          Const ("HOL.undefined", rT))
        |> HOLogic.mk_Trueprop
        |> fold_rev Logic.all qs
      end
  in
    map mk_eqn fixes
  end
86 |
86 |
(* Appends the catch-all equations of mk_catchall to spec.  The arity of a
   function is the number of arguments found for its name among the
   equations of spec; the lookup goes through `the`, so every function in
   fixes needs an equation in spec. *)
fun add_catchall ctxt fixes spec =
  let
    val fqgars = map (split_def ctxt) spec
    val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
                   |> AList.lookup (op =) #> the
  in
    spec @ mk_catchall fixes arity_of
  end
94 |
94 |
(* Emits a warning for every equation in origs whose corresponding entry in
   tss is empty.  Only the first (length origs) entries of tss are examined;
   tss is returned unchanged. *)
fun warn_if_redundant ctxt origs tss =
  let
    fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)

    val (tss', _) = chop (length origs) tss
    fun check (t, []) = warning (msg t)
      | check _ = ()
  in
    (List.app check (origs ~~ tss'); tss)
  end

(* Preprocessor installed by setup.  With sequential = false it defers to
   Function_Common.empty_preproc.  Otherwise it checks the equations
   (check_defs, check_pats), adds the catch-all equations, splits them with
   Function_Split.split_all_equations, and returns
     (split equations, restore_spec, sort, case_names). *)
fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
  if sequential then
    let
      val (bnds, eqss) = split_list spec

      (* each spec entry is expected to carry exactly one equation *)
      val eqs = map the_single eqss

      val feqs = eqs
        |> tap (check_defs ctxt fixes) (* Standard checks *)
        |> tap (List.app (check_pats ctxt)) (* More checks for sequential mode *)

      val compleqs = add_catchall ctxt fixes feqs (* Completion *)

      val spliteqs = warn_if_redundant ctxt feqs
        (Function_Split.split_all_equations ctxt compleqs)

      (* regroup thms like spliteqs and keep only the groups of the
         user-supplied equations, paired with their bindings *)
      fun restore_spec thms =
        bnds ~~ take (length bnds) (unflat spliteqs thms)

      val spliteqs' = flat (take (length bnds) spliteqs)
      val fnames = map (fst o fst) fixes
      val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'

      (* group xs by the position in fnames of the function each equation defines *)
      fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
        |> map (map snd)

      (* pad the bindings with empty ones for the added catch-all groups *)
      val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding

      (* using theorem names for case name currently disabled *)
      val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
        (bnds' ~~ spliteqs) |> flat
    in
      (flat spliteqs, restore_spec, sort, case_names)
    end
  else
    Function_Common.empty_preproc check_defs config ctxt fixes spec
|
145 |
143 |
(* Theory setup: registers sequential_preproc via Function_Common.set_preproc. *)
val setup =
  Context.theory_map (Function_Common.set_preproc sequential_preproc)
148 |
146 |
149 |
147 |
(* Function configuration with sequential mode switched on and all other
   flags off. *)
val fun_config = FunctionConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*),
  domintros=false, partials=false, tailrec=false }
152 |
150 |
(* Common body of add_fun and add_fun_cmd.  Runs the definition function
   `add`, closes the resulting proof with by_pat_completeness_auto, restores
   the local theory, and then proves termination with the prover obtained
   from Function_Common.get_termination_prover on the initial lthy. *)
fun gen_fun add config fixes statements int lthy =
  lthy
    |> add fixes statements config
    |> by_pat_completeness_auto int
    |> Local_Theory.restore
    |> termination_by (Function_Common.get_termination_prover lthy) int
159 |
157 |
(* Instances of gen_fun over Function.add_function and
   Function.add_function_cmd respectively. *)
val add_fun = gen_fun Function.add_function
val add_fun_cmd = gen_fun Function.add_function_cmd
162 |
160 |
163 |
161 |