26 |
26 |
27 fun primrec_error str = raise Primrec_Error (str, []); |
27 fun primrec_error str = raise Primrec_Error (str, []); |
28 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]); |
28 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]); |
29 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns); |
29 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns); |
30 |
30 |
|
31 val free_name = try (fn Free (v, _) => v); |
|
32 val const_name = try (fn Const (v, _) => v); |
|
33 |
31 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x); |
34 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x); |
32 fun abs_tuple t = if try (fst o dest_Const) t = SOME @{const_name undefined} then t else |
35 fun abs_tuple t = if const_name t = SOME @{const_name undefined} then t else |
33 strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda; |
36 strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda; |
34 |
37 |
35 val simp_attrs = @{attributes [simp]}; |
38 val simp_attrs = @{attributes [simp]}; |
36 |
39 |
37 |
40 |
101 res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs, |
104 res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs, |
102 rhs_term = rhs, |
105 rhs_term = rhs, |
103 user_eqn = eqn'} |
106 user_eqn = eqn'} |
104 end; |
107 end; |
105 |
108 |
106 fun rewrite_map_arg funs_data get_indices y rec_type res_type = |
109 fun rewrite_map_arg funs_data get_indices rec_type res_type = |
107 let |
110 let |
|
111 val fun_data = hd (the (find_first (equal rec_type o #rec_type o hd) funs_data)); |
|
112 val fun_name = #fun_name fun_data; |
|
113 val ctr_pos = length (#left_args fun_data); |
|
114 |
108 val pT = HOLogic.mk_prodT (rec_type, res_type); |
115 val pT = HOLogic.mk_prodT (rec_type, res_type); |
109 val fstx = fst_const pT; |
116 |
110 val sndx = snd_const pT; |
117 val maybe_suc = Option.map (fn x => x + 1); |
111 |
118 fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT) |
112 val SOME ({fun_name, left_args, ...} :: _) = |
119 | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b) |
113 find_first (equal rec_type o #rec_type o hd) funs_data; |
120 | subst d t = |
114 val ctr_pos = length left_args; |
|
115 |
|
116 fun subst _ d (t as Bound d') = t |> d = d' ? curry (op $) fstx |
|
117 | subst l d (Abs (v, T, b)) = Abs (v, if d < 0 then pT else T, subst l (d + 1) b) |
|
118 | subst l d t = |
|
119 let val (u, vs) = strip_comb t in |
121 let val (u, vs) = strip_comb t in |
120 if try (fst o dest_Free) u = SOME fun_name then |
122 if free_name u = SOME fun_name then |
121 if l andalso length vs = ctr_pos then |
123 if d = SOME ~1 andalso length vs = ctr_pos then |
122 list_comb (sndx |> permute_args ctr_pos, vs) |
124 list_comb (permute_args ctr_pos (snd_const pT), vs) |
123 else if length vs <= ctr_pos then |
125 else if length vs > ctr_pos andalso is_some d |
124 primrec_error_eqn "too few arguments in recursive call" t |
126 andalso d = try (fn Bound n => n) (nth vs ctr_pos) then |
125 else if nth vs ctr_pos |> member (op =) [y, Bound d] then |
127 list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs)) |
126 list_comb (sndx $ nth vs ctr_pos, nth_drop ctr_pos vs |> map (subst false d)) |
|
127 else |
128 else |
128 primrec_error_eqn "recursive call not directly applied to constructor argument" t |
129 primrec_error_eqn ("recursive call not directly applied to constructor argument") t |
129 else if try (fst o dest_Const) u = SOME @{const_name comp} then |
130 else if d = SOME ~1 andalso const_name u = SOME @{const_name comp} then |
130 (hd vs |> get_indices |> null orelse |
131 list_comb (map_types (K dummyT) u, map2 subst [NONE, d] vs) |
131 primrec_error_eqn "recursive call not directly applied to constructor argument" t; |
|
132 list_comb |
|
133 (u |> map_types (strip_type #>> (fn Ts => Ts |
|
134 |> nth_map (length Ts - 1) (K pT) |
|
135 |> nth_map (length Ts - 2) (strip_type #>> nth_map 0 (K pT) #> (op --->))) |
|
136 #> (op --->)), |
|
137 nth_map 1 (subst l d) vs)) |
|
138 else |
132 else |
139 list_comb (u, map (subst false d) vs) |
133 list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs) |
140 end |
134 end |
141 in |
135 in |
142 subst true ~1 |
136 subst (SOME ~1) |
143 end; |
137 end; |
144 |
138 |
145 (* FIXME get rid of funs_data or get_indices *) |
139 (* FIXME get rid of funs_data or get_indices *) |
146 fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t = |
140 fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t = |
147 let |
141 let |
162 orelse primrec_error_eqn "too few arguments in recursive call" t; |
156 orelse primrec_error_eqn "too few arguments in recursive call" t; |
163 list_comb (the maybe_direct_y', g_args)) |
157 list_comb (the maybe_direct_y', g_args)) |
164 else if is_some maybe_indirect_y' then |
158 else if is_some maybe_indirect_y' then |
165 (if contains_fun g then t else y) |
159 (if contains_fun g then t else y) |
166 |> massage_indirect_rec_call lthy contains_fun |
160 |> massage_indirect_rec_call lthy contains_fun |
167 (rewrite_map_arg funs_data get_indices y) bound_Ts y (the maybe_indirect_y') |
161 (rewrite_map_arg funs_data get_indices) bound_Ts y (the maybe_indirect_y') |
168 |> (if contains_fun g then I else curry (op $) g) |
162 |> (if contains_fun g then I else curry (op $) g) |
169 else |
163 else |
170 t |
164 t |
171 end |
165 end |
172 | subst _ t = t |
166 | subst _ t = t |
424 if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs; |
418 if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs; |
425 val ctr_no = if not_disc then 1 - ctr_no' else ctr_no'; |
419 val ctr_no = if not_disc then 1 - ctr_no' else ctr_no'; |
426 val fun_args = if is_none disc |
420 val fun_args = if is_none disc |
427 then imp_rhs |> perhaps (try HOLogic.dest_not) |> HOLogic.dest_eq |> fst |> strip_comb |> snd |
421 then imp_rhs |> perhaps (try HOLogic.dest_not) |> HOLogic.dest_eq |> fst |> strip_comb |> snd |
428 else the disc |> the_single o snd o strip_comb |
422 else the disc |> the_single o snd o strip_comb |
429 |> (fn t => if try (fst o dest_Free o head_of) t = SOME fun_name |
423 |> (fn t => if free_name (head_of t) = SOME fun_name |
430 then snd (strip_comb t) else []); |
424 then snd (strip_comb t) else []); |
431 |
425 |
432 val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True}; |
426 val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True}; |
433 val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False}; |
427 val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False}; |
434 val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_; |
428 val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_; |