| author | wenzelm | 
| Wed, 15 Jul 2020 11:56:43 +0200 | |
| changeset 72034 | 452073b64f28 | 
| parent 69593 | 3dda49e08b9d | 
| child 80634 | a90ab1ea6458 | 
| permissions | -rw-r--r-- | 
| 68155 | 1 | (* Author: Pascal Stoop, ETH Zurich | 
| 2 | Author: Andreas Lochbihler, Digital Asset *) | |
| 3 | ||
(* Interface of the case converter used by case_of_simps.
   An elimination strategy decides, for the patterns found at one argument
   position, which sub-patterns get turned into case expressions. *)
signature CASE_CONVERTER =
sig
  type elimination_strategy

  (* Strategy driven by a predicate on (type constructor name, constructor name). *)
  val replace_by_type: (Proof.context -> string * string -> bool) -> elimination_strategy

  (* Strategy that descends through constructor contexts shared by all patterns. *)
  val keep_constructor_context: elimination_strategy

  (* to_case ctxt strategy ctr_count thms.  The int-valued argument is presumably
     the number of constructors of a constant's datatype -- TODO confirm. *)
  val to_case: Proof.context -> elimination_strategy -> (string * typ -> int) ->
    thm list -> thm list option
end;
| 12 | ||
| 13 | structure Case_Converter : CASE_CONVERTER = | |
| 14 | struct | |
| 15 | ||
(* lookup_remove eq k kvs: find the first pair (k', v) in kvs with eq (k, k').
   Returns SOME of that pair (keeping its own key k') together with all other
   pairs in their original order, or (NONE, kvs) if no key matches. *)
fun lookup_remove eq k kvs =
  let
    fun go [] = (NONE, [])
      | go ((entry as (key, _)) :: rest) =
          if eq (k, key) then (SOME entry, rest)
          else
            let val (found, others) = go rest
            in (found, entry :: others) end
  in
    go kvs
  end
| 20 | ||
(* Build "missing_pattern_match msg (%_::unit. t)": a term of the same type as t
   that carries msg as a HOL literal and t as a unit-thunk. *)
fun mk_abort msg t =
  let
    val resT = fastype_of t
    val abortT = HOLogic.literalT --> (HOLogic.unitT --> resT) --> resT
    val abort_const = Const (\<^const_name>\<open>missing_pattern_match\<close>, abortT)
    val literal = HOLogic.mk_literal msg
    val thunk = absdummy HOLogic.unitT t
  in
    abort_const $ literal $ thunk
  end
| 28 | ||
(* Bottom-up fold over a term.  Leaves (Const, Free, Var, Bound) are handed to
   the matching function; Abs and application nodes receive the already folded
   results of their sub-terms (application: left result first).

   fold_term : (string * typ -> 'a) ->
               (string * typ -> 'a) ->
               (indexname * typ -> 'a) ->
               (int -> 'a) ->
               (string * typ * 'a -> 'a) ->
               ('a * 'a -> 'a) ->
               term -> 'a *)
fun fold_term const_fun free_fun var_fun bound_fun abs_fun dollar_fun term =
  let
    fun walk (Const c) = const_fun c
      | walk (Free v) = free_fun v
      | walk (Var v) = var_fun v
      | walk (Bound i) = bound_fun i
      | walk (Abs (x, T, body)) = abs_fun (x, T, walk body)
      | walk (t $ u) = dollar_fun (walk t, walk u)
  in
    walk term
  end;
| 49 | ||
(* A term coordinate is a tree of paths into patterns: Coordinate lists, per
   (constructor name, argument index), the coordinate below that argument;
   End T marks the position to eliminate, T being its type. *)
datatype term_coordinate =
    End of typ
  | Coordinate of (string * (int * term_coordinate)) list;

(* Merge two coordinate trees; an End on either side wins (left first).
   Branches are keyed by (constructor name, argument index); every left branch
   with the same key is consumed, and only the first of them is merged. *)
fun term_coordinate_merge (End T) _ = End T
  | term_coordinate_merge _ (End T) = End T
  | term_coordinate_merge (Coordinate left) (Coordinate right) =
      let
        fun same_key (s, n) (s', (n', _)) = s = s' andalso n = n'
        fun merge rest [] = rest
          | merge rest ((s, (n, sub)) :: more) =
              let
                val (hits, misses) = List.partition (same_key (s, n)) rest
                val merged =
                  (case hits of
                    [] => sub
                  | (_, (_, sub')) :: _ => term_coordinate_merge sub' sub)
              in
                (s, (n, merged)) :: merge misses more
              end
      in
        Coordinate (merge left right)
      end;

(* Flatten a coordinate tree into (type at End, path) pairs, where a path is
   the list of (constructor name, argument index) steps from the root. *)
fun coordinates_to_list (End T) = [(T, [])]
  | coordinates_to_list (Coordinate branches) =
      maps (fn (s, (n, sub)) =>
          map (fn (T, path) => (T, (s, n) :: path)) (coordinates_to_list sub))
        branches;
| 71 | ||
| 69568 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
(* Given the patterns found at one argument position (one per equation, see
   build_case_t), return the coordinates of the sub-patterns to eliminate. *)
type elimination_strategy = Proof.context -> (term list -> term_coordinate list)
| 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
changeset | 73 | |
| 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
(* Elimination strategy: a constructor occurrence c whose result type is the
   datatype tyco is eliminated (End) when replace_ctr ctxt (tyco, c) holds;
   otherwise the strategy descends into its arguments.  Patterns containing
   nothing to eliminate contribute no coordinate.  The result type of every
   head constant must be a Type (dest_Type). *)
fun replace_by_type replace_ctr ctxt pats =
  let
    val should_replace = replace_ctr ctxt
    fun coordinates_of t =
      (case strip_comb t of
        (Const (c, T), args) =>
          let
            val bodyT = body_type T
          in
            if should_replace (fst (dest_Type bodyT), c) then SOME (End bodyT)
            else
              (case map_filter I
                  (map_index (fn (i, arg) => Option.map (pair i) (coordinates_of arg)) args) of
                [] => NONE
              | subs => SOME (Coordinate (map (pair c) subs)))
          end
      | _ => NONE)
  in
    map_filter coordinates_of pats
  end
| 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
changeset | 96 | |
| 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
(* Elimination strategy that keeps constructor contexts shared by all patterns.
   If every pattern has the same head f and f is not the unique constructor of
   its datatype, descend into the argument columns.  Otherwise, eliminate the
   whole position (End with the result type of the first head).
   Raises ERROR ("Not a free constructor ...") when a head constant's result
   type has no free-constructor (Ctr_Sugar) registration. *)
fun keep_constructor_context ctxt pats =
  let
    fun to_coordinates [] = NONE
      | to_coordinates pats =
        let
          val (fs, argss) = map strip_comb pats |> split_list
          val f = hd fs
          (* true iff the constant is the only constructor of its datatype *)
          fun is_single_ctr (Const (name, T)) =
              let
                val tyco = body_type T |> dest_Type |> fst
              in
                (* The lookup result is inspected only here, so that a missing
                   registration reaches the readable error below. *)
                case Ctr_Sugar.ctr_sugar_of ctxt tyco of
                  NONE => error ("Not a free constructor " ^ name ^ " in pattern")
                | SOME info =>
                  (case #ctrs info of
                    [Const (name', _)] => name = name'
                  | _ => false)
              end
            | is_single_ctr _ = false
        in
          if not (is_single_ctr f) andalso forall (fn x => f = x) fs then
            let
              (* one column per constructor argument, one entry per pattern *)
              val patss = Ctr_Sugar_Util.transpose argss
              fun recurse (i, pats) = to_coordinates pats |> Option.map (pair i)
              val coords = map_filter I (map_index recurse patss)
            in
              if null coords then NONE
              else SOME (Coordinate (map (pair (dest_Const f |> fst)) coords))
            end
          else SOME (End (body_type (fastype_of f)))
        end
  in
    the_list (to_coordinates pats)
  end
| 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
changeset | 131 | |
| 68155 | 132 | |
(* Look up (and remove) the entry of xs whose constructor term has the same
   constant name as ctr1; result shape as for lookup_remove.
   AL: TODO: change from term to const_name *)
fun find_ctr ctr1 xs =
  lookup_remove (fn (a, b) => #1 (dest_Const a) = #1 (dest_Const b)) ctr1 xs;
| 141 | ||
(* Coverage trees over argument tuples, as interpreted by pattern_contains:
   Wildcard = not covered, Value = covered, and
   Split (n, branches, default) = distinction on the next argument, with one
   sub-tree per constructor term and a default sub-tree for non-constructor
   arguments.  n is the constructor count (0 when unknown). *)
datatype pattern =
    Wildcard
  | Value
  | Split of int * (term * pattern) list * pattern;

(* Merge two coverage trees: Wildcard is neutral and Value absorbs; two Splits
   are merged branch-wise, keyed by constructor name. *)
fun pattern_merge Wildcard other = other
  | pattern_merge Value _ = Value
  | pattern_merge (Split _) Value = Value
  | pattern_merge (Split (n, branches, default)) Wildcard =
      Split (n, map (apsnd (fn p => pattern_merge p Wildcard)) branches,
        pattern_merge default Wildcard)
  | pattern_merge (Split (n, branches, default)) (Split (m, branches', default')) =
      let
        fun merge_branches rest [] = map (apsnd (fn p => pattern_merge p Wildcard)) rest
          | merge_branches rest ((ctr, p') :: more) =
              (case find_ctr ctr rest of
                (SOME (ctr0, p), rest') => (ctr0, pattern_merge p p') :: merge_branches rest' more
              | (NONE, rest') => (ctr, p') :: merge_branches rest' more)
        val count = if n <= 0 then m else n
      in
        Split (count, merge_branches branches branches', pattern_merge default default')
      end
| 163 | ||
(* Intersect two coverage trees: Wildcard absorbs and Value is neutral.
   Two Splits are intersected branch-wise by intersect_consts, which also
   accounts for each side's default sub-tree. *)
fun pattern_intersect Wildcard _ = Wildcard
  | pattern_intersect Value other = other
  | pattern_intersect (Split _) Wildcard = Wildcard
  | pattern_intersect (Split (n, branches, default)) Value =
      Split (n, map (apsnd (fn p => pattern_intersect p Value)) branches,
        pattern_intersect default Value)
  | pattern_intersect (Split (n, branches, default)) (Split (m, branches', default')) =
      let
        val count = if n <= 0 then m else n
      in
        Split (count, intersect_consts branches branches' default default',
          pattern_intersect default default')
      end
and intersect_consts xs [] _ default2 =
      map (apsnd (fn p => pattern_intersect p default2)) xs
  | intersect_consts xs ((ctr, p2) :: more) default1 default2 =
      (case find_ctr ctr xs of
        (SOME (ctr1, p1), rest) =>
          let
            val both = pattern_intersect p1 p2
            val via_default1 = pattern_intersect default1 p2
            val merged = pattern_merge both via_default1
            val via_default2 = pattern_intersect p1 default2
          in
            (ctr1, pattern_merge merged via_default2) ::
              intersect_consts rest more default1 default2
          end
      | (NONE, rest) =>
          (ctr, pattern_intersect default1 p2) :: intersect_consts rest more default1 default2)
| 183 | ||
(* Restrict a coverage tree to the argument terms: follow constructor heads
   into their branches.  For a non-constructor argument at a Split, intersect
   all branches (only when the Split lists at least n > 0 of them) and merge
   with the default.  Leftover structure is normalised by looking up []. *)
fun pattern_lookup _ Wildcard = Wildcard
  | pattern_lookup _ Value = Value
  | pattern_lookup [] (Split (n, branches, default)) =
      Split (n, map (apsnd (pattern_lookup [])) branches, pattern_lookup [] default)
  | pattern_lookup (t :: ts) (Split (n, branches, default)) =
      let
        val (head, args) = strip_comb t
        (* look up a branch with unconstrained ("x") arguments for its constructor *)
        fun lookup_branch (ctr, p) =
          pattern_lookup (map (fn T => Free ("x", T)) (binder_types (snd (dest_Const ctr)))) p
        val next =
          if is_Const head then
            (case find_ctr head branches of
              (SOME (_, p), _) => pattern_merge (pattern_lookup args p) (pattern_lookup [] default)
            | (NONE, _) => default)
          else if length branches < n orelse n <= 0 then default
          else
            let
              val rest_results = map lookup_branch (tl branches)
              val first_result = lookup_branch (hd branches)
            in
              pattern_merge (fold pattern_intersect rest_results first_result)
                (pattern_lookup [] default)
            end
      in
        pattern_lookup ts next
      end;
| 209 | ||
(* Is the argument list covered by pat?  A Split left over after lookup is not
   expected and raises Match. *)
fun pattern_contains terms pat =
  (case pattern_lookup terms pat of
    Value => true
  | Wildcard => false
  | Split _ => raise Match);
| 214 | ||
(* Build an (uncovered) Split skeleton following the argument terms:
   constructor heads get a single branch (count from ctr_count) over their
   arguments followed by the remaining terms; other terms get a count-0 Split
   whose default continues with the remaining terms. *)
fun pattern_create _ [] = Wildcard
  | pattern_create ctr_count (t :: ts) =
      (case strip_comb t of
        (head as Const _, args) =>
          Split (ctr_count head, [(head, pattern_create ctr_count (args @ ts))], Wildcard)
      | _ => Split (0, [], pattern_create ctr_count ts));
| 224 | ||
(* Insert the argument list terms into the coverage tree pat.  Reaching a
   Wildcard with all terms consumed yields Value; missing branches are built
   with pattern_create.  ctr_count supplies constructor counts for new Splits
   (presumably the number of constructors of the datatype -- TODO confirm). *)
fun pattern_insert ctr_count terms pat =
  let
    (* build a skeleton for terms and insert terms into it *)
    fun new_pattern terms = pattern_insert ctr_count terms (pattern_create ctr_count terms)
    (* aux terms modify pat: walk pat along terms.  Called as aux [] false it
       rebuilds the sub-tree unchanged. *)
    fun aux _ false Wildcard = Wildcard
      | aux terms true Wildcard = if null terms then Value else new_pattern terms
      | aux _ _ Value = Value
      | aux terms modify (Split (n, xs', pat)) =
      let
        val unmodified = (n, map (apsnd (aux [] false)) xs', aux [] false pat)
      in case terms of [] => Split unmodified
        | term :: terms =>
        let
          val (ctr, args) = strip_comb term
          val (m, ys, pat') = unmodified
        in
          if is_Const ctr
            then case find_ctr ctr xs' of
              (* existing branch: continue with the constructor arguments *)
              (SOME (ctr, pat''), xs) =>
                Split (m, (ctr, aux (args @ terms) modify pat'') :: map (apsnd (aux [] false)) xs, pat')
              (* no branch yet: add one; a Split without a count takes it from ctr_count *)
              | (NONE, _) => if modify
                then if m <= 0
                  then Split (ctr_count ctr, (ctr, new_pattern (args @ terms)) :: ys, pat')
                  else Split (m, (ctr, new_pattern (args @ terms)) :: ys, pat')
                else Split unmodified
          (* non-constructor argument: insert the remaining terms into the default sub-tree *)
          else Split (m, ys, aux terms modify pat)
        end
      end
  in
    aux terms true pat
  end;
| 255 | ||
(* The coverage tree that covers nothing. *)
val pattern_empty : pattern = Wildcard;
| 257 | ||
(* Fix fresh variables for the conversion and rename the free variables of
   every equation apart.
   Returns (renamed lhss, renamed rhss, pairs of fresh frees of types typ_list,
   fresh frees for the arguments of the first lhs, extended context). *)
fun replace_frees lhss rhss typ_list ctxt =
  let
    (* fix name variants of the given Frees, keeping their types *)
    fun variant_fixes (frees, ctxt0) =
      let
        val (names, Ts) = split_list (map dest_Free frees)
        val (names', ctxt') = Variable.variant_fixes names ctxt0
      in
        (map Free (names' ~~ Ts), ctxt')
      end

    (* rename the Frees of one equation (taken from its lhs) to fresh ones *)
    fun replace_frees_once (lhs, rhs) ctxt0 =
      let
        val old = fold_rev Term.add_frees lhs []
        val (fresh, ctxt') = variant_fixes (Ctr_Sugar_Util.mk_Frees "x" (map snd old) ctxt0)
        val renaming = old ~~ fresh
        fun rename v =
          (case AList.lookup (op =) renaming v of
            SOME t => t
          | NONE => Free v)
        val subst = fold_term Const rename Var Bound Abs (op $)
      in
        ((map subst lhs, subst rhs), ctxt')
      end

    val (def_frees, ctxt1) = variant_fixes (Ctr_Sugar_Util.mk_Frees "x" typ_list ctxt)
    val (rhs_frees, ctxt2) = variant_fixes (Ctr_Sugar_Util.mk_Frees "x" typ_list ctxt1)
    val (case_args, ctxt3) =
      variant_fixes (Ctr_Sugar_Util.mk_Frees "x" (map fastype_of (hd lhss)) ctxt2)
    val ((lhss', rhss'), ctxt4) =
      fold_map replace_frees_once (lhss ~~ rhss) ctxt3 |>> split_list
  in
    (lhss', rhss', def_frees ~~ rhs_frees, case_args, ctxt4)
  end;
| 298 | ||
(* Table updater adding every type constructor name occurring in a type. *)
fun add_names_in_type (Type (name, Ts)) =
      fold_rev (fn T => fn acc => add_names_in_type T o acc) Ts (Symtab.update (name, ()))
  | add_names_in_type (TFree _) = I
  | add_names_in_type (TVar _) = I

(* Table updater adding every type constructor name occurring in the types of a term. *)
fun add_names_in_term t =
  fold_term
    (add_names_in_type o snd)
    (add_names_in_type o snd)
    (add_names_in_type o snd)
    (K I)
    (fn (_, T, body_fn) => add_names_in_type T o body_fn)
    (op o)
    t

(* Compose the updaters of all given terms. *)
fun add_type_names terms =
  List.foldl (fn (t, acc) => add_names_in_term t o acc) I terms

(* Split rules of all names in the table that have a free-constructor registration. *)
fun get_split_theorems ctxt tab =
  Symtab.keys tab
  |> map_filter (Ctr_Sugar.ctr_sugar_of ctxt)
  |> map #split;
| 319 | ||
(* First-order matching of a pattern (Const / Free / application) against a
   term; on success returns the substitution as a function on terms. *)
fun match (Const (c, _)) (Const (c', _)) = if c = c' then SOME I else NONE
  | match (Free v) t = SOME (fn u => if u = Free v then t else u)
  | match (p1 $ p2) (t1 $ t2) =
      let
        val left = match p1 t1
        val right = match p2 t2
      in
        (case (left, right) of
          (SOME f, SOME g) => SOME (f o g)
        | _ => NONE)
      end
  | match _ _ = NONE;

(* Match argument lists pointwise (right to left); combined substitution or NONE. *)
fun match_all patterns terms =
  let
    fun step (_, NONE) = NONE
      | step ((p, t), SOME acc) = Option.map (fn f => f o acc) (match p t)
  in
    List.foldr step (SOME I) (patterns ~~ terms)
  end

(* Boolean variant of match, without building a substitution. *)
fun matches (Const (c, _)) (Const (c', _)) = c = c'
  | matches (Free _) _ = true
  | matches (p1 $ p2) (t1 $ t2) = matches p1 t1 andalso matches p2 t2
  | matches _ _ = false;
fun matches_all patterns terms = List.all (fn (p, t) => matches p t) (patterns ~~ terms)
| 343 | ||
(* Eliminate one position from all equations (lhss, rhss).
   pos = (n, tcos): argument index n and a path tcos of (constructor name,
   argument index) steps to the sub-pattern to eliminate.  In every lhs that
   sub-pattern is replaced by lazy_case_arg; when the path does not match,
   rhs_free stands in as the extracted sub-pattern.
   default_lhs is appended as a final lhs.  Each case list built by find_rhss
   ends with an abort alternative (mk_abort "Missing pattern in ...").
   Returns the new lhss, the new rhss (Case_Translation.make_case over
   lazy_case_arg), and type_name_fun extended with the type names of the
   extracted sub-patterns. *)
fun terms_to_case_at ctr_count ctxt (fun_t : term) (default_lhs : term list)
    (pos, (lazy_case_arg, rhs_free))
    ((lhss : term list list), (rhss : term list), type_name_fun) =
  let
    fun abort t =
      let
        val fun_name = head_of t |> dest_Const |> fst
        val msg = "Missing pattern in " ^ fun_name ^ "."
      in
        mk_abort msg t
      end;

    (* Step 1 : Eliminate lazy pattern *)
    (* returns (lhs with the sub-pattern replaced by pat, extracted sub-pattern) *)
    fun replace_pat_at (n, tcos) pat pats =
      let
        (* apply f to the n-th element; f returns (new element, extra result) *)
        fun map_at _ _ [] = raise Empty
          | map_at n f (x :: xs) = if n > 0
            then apfst (cons x) (map_at (n - 1) f xs)
            else apfst (fn x => x :: xs) (f x)
        fun replace [] pat term = (pat, term)
          | replace ((s1, n) :: tcos) pat term =
            let
              val (ctr, args) = strip_comb term
            in
              case ctr of Const (s2, _) =>
                if s1 = s2
                then apfst (pair ctr #> list_comb) (map_at n (replace tcos pat) args)
                else (term, rhs_free)
                | _ => (term, rhs_free)
            end
        val (part1, (old_pat, part2)) = chop n pats ||> (fn xs => (hd xs, tl xs))
        val (new_pat, old_pat1) = replace tcos pat old_pat
      in
        (part1 @ [new_pat] @ part2, old_pat1)
      end
    val (lhss1, lazy_pats) = map (replace_pat_at pos lazy_case_arg) lhss
      |> split_list

    (* Step 2 : Split patterns *)
    fun split equs =
      let
        (* unify two patterns where a Free matches anything; NONE on clash *)
        fun merge_pattern (Const (s1, T1), Const (s2, _)) =
            if s1 = s2 then SOME (Const (s1, T1)) else NONE
          | merge_pattern (t, Free _) = SOME t
          | merge_pattern (Free _, t) = SOME t
          | merge_pattern (t1l $ t1r, t2l $ t2r) =
            (case (merge_pattern (t1l, t2l), merge_pattern (t1r, t2r)) of
              (SOME t1, SOME t2) => SOME (t1 $ t2)
              | _ => NONE)
          | merge_pattern _ = NONE
        (* pointwise merge_pattern; lists of different length raise Match *)
        fun merge_patterns pats1 pats2 = case (pats1, pats2) of
          ([], []) => SOME []
          | (x :: xs, y :: ys) =>
            (case (merge_pattern (x, y), merge_patterns xs ys) of
              (SOME x, SOME xs) => SOME (x :: xs)
              | _ => NONE
            )
          | _ => raise Match
        fun merge_insert ((lhs1, case_pat), _) [] =
            [(lhs1, pattern_empty |> pattern_insert ctr_count [case_pat])]
          | merge_insert ((lhs1, case_pat), rhs) ((lhs2, pat) :: pats) =
            let
              val pats = merge_insert ((lhs1, case_pat), rhs) pats
              val (first_equ_needed, new_lhs) = case merge_patterns lhs1 lhs2 of
                SOME new_lhs => (not (pattern_contains [case_pat] pat), new_lhs)
                | NONE => (false, lhs2)
              val second_equ_needed = not (matches_all lhs1 lhs2)
                orelse not first_equ_needed
              val first_equ = if first_equ_needed
                then [(new_lhs, pattern_insert ctr_count [case_pat] pat)]
                else []
              val second_equ = if second_equ_needed
                then [(lhs2, pat)]
                else []
            in
              first_equ @ second_equ @ pats
            end
        in
          (fold merge_insert equs []
            |> split_list
            |> fst) @ [default_lhs]
        end
    val lhss2 = split ((lhss1 ~~ lazy_pats) ~~ rhss)

    (* Step 3 : Remove redundant patterns *)
    (* keep an lhs only if it is not already covered by the preceding ones *)
    fun remove_redundant_lhs lhss =
      let
        fun f lhs pat = if pattern_contains lhs pat
          then ((lhs, false), pat)
          else ((lhs, true), pattern_insert ctr_count lhs pat)
      in
        fold_map f lhss pattern_empty
        |> fst
        |> filter snd
        |> map fst
      end
    (* same for (case pattern, rhs) alternatives *)
    fun remove_redundant_rhs rhss =
      let
        fun f (lhs, rhs) pat = if pattern_contains [lhs] pat
          then (((lhs, rhs), false), pat)
          else (((lhs, rhs), true), pattern_insert ctr_count [lhs] pat)
      in
        map fst (filter snd (fold_map f rhss pattern_empty |> fst))
      end
    val lhss3 = remove_redundant_lhs lhss2

    (* Step 4 : Compute right hand side *)
    (* apply the substitution f at Frees and to both sides of every application *)
    fun subs_fun f = fold_term
      Const
      (f o Free)
      Var
      Bound
      Abs
      (fn (x, y) => f x $ f y)
    (* case alternatives for lhs: instantiated rhss of all matching original
       equations, then the abort fallback, with redundant ones removed *)
    fun find_rhss lhs =
      let
        fun f (lhs1, (pat, rhs)) =
          case match_all lhs1 lhs of NONE => NONE
          | SOME f => SOME (pat, subs_fun f rhs)
      in
        remove_redundant_rhs
          (map_filter f (lhss1 ~~ (lazy_pats ~~ rhss)) @
            [(lazy_case_arg, list_comb (fun_t, lhs) |> abort)]
          )
      end

    (* Step 5 : make_case of right hand side *)
    (* a single variable alternative is substituted directly, without a case *)
    fun make_case ctxt case_arg cases = case cases of
      [(Free x, rhs)] => subs_fun (fn y => if y = Free x then case_arg else y) rhs
      | _ => Case_Translation.make_case
        ctxt
        Case_Translation.Warning
        Name.context
        case_arg
        cases
    val type_name_fun = add_type_names lazy_pats o type_name_fun
    val rhss3 = map ((make_case ctxt lazy_case_arg) o find_rhss) lhss3
  in
    (lhss3, rhss3, type_name_fun)
  end;
| 484 | ||
(* Rename the equations apart, eliminate every position in poss (the last
   position first, as with fold_rev), and return the resulting HOL equations
   "head lhs = rhs", the split rules of the datatypes involved, and the
   extended context. *)
fun terms_to_case ctxt ctr_count (head : term) (lhss : term list list)
    (rhss : term list) (typ_list : typ list) (poss : (int * (string * int) list) list) =
  let
    val (lhss1, rhss1, def_frees, case_args, ctxt1) = replace_frees lhss rhss typ_list ctxt
    val step = terms_to_case_at ctr_count ctxt1 head case_args
    val (lhss2, rhss2, type_name_fun) =
      List.foldr (fn (job, acc) => step job acc) (lhss1, rhss1, I) (poss ~~ def_frees)
    fun mk_equation (args, rhs) =
      HOLogic.mk_Trueprop (HOLogic.mk_eq (list_comb (head, args), rhs))
    val equations = map mk_equation (lhss2 ~~ rhss2)
    val splits = get_split_theorems ctxt1 (type_name_fun Symtab.empty)
  in
    (equations, splits, ctxt1)
  end;
| 500 | ||
| 69568 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
changeset | 501 | |
| 
de09a7261120
new implementation for case_of_simps based on Code_Lazy's pattern matching elimination algorithm
 Andreas Lochbihler parents: 
68301diff
(* Check the shape of the equations, ask the elimination strategy for the
   positions to eliminate in each argument column, and run terms_to_case.
   Returns ([], [], ctxt) when the strategy selects nothing.
   Raises Fail on an empty or inconsistent equation list. *)
fun build_case_t elimination_strategy ctr_count head lhss rhss ctxt =
  let
    val num_eqs = length lhss
    val _ =
      if num_eqs > 0 andalso length rhss = num_eqs then ()
      else raise Fail
        ("expected same number of left-hand sides as right-hand sides\n"
          ^ "and at least one equation")
    val arity = length (hd lhss)
    val _ =
      if forall (fn args => length args = arity) lhss then ()
      else raise Fail "expected equal number of arguments"

    (* (type, (column index, path)) for every position selected in column i *)
    fun column_positions (i, column) =
      (case elimination_strategy ctxt column of
        [] => []
      | tco :: tcos =>
          map (fn (T, path) => (T, (i, path)))
            (coordinates_to_list (fold term_coordinate_merge tcos tco)))
    val (typ_list, poss) =
      Ctr_Sugar_Util.transpose lhss
      |> map_index column_positions
      |> flat
      |> split_list
  in
    if null poss then ([], [], ctxt)
    else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss
  end;
| 528 | ||
(* Proof tactic for one case-converted equation.
   1. Repeatedly apply the given [splits] (split_tac).  After each split,
      break the result apart with conjI/allI, or with impI followed by
      hypothesis substitution.  If no split applies, keep the goal unchanged.
   2. Unfold missing_pattern_match_def and then the right-hand-side
      definitions [defs].
   3. Close every remaining subgoal with refl or one of the original
      equations [intros].  SOLVED' makes a subgoal fail unless it is fully
      closed. *)
fun tac ctxt {splits, intros, defs} =
  let
    val intro_step =
      resolve_tac ctxt [@{thm conjI}, @{thm allI}]
      ORELSE' (resolve_tac ctxt [@{thm impI}] THEN' hyp_subst_tac_thin true ctxt)
    val split_and_subst = split_tac ctxt splits THEN' REPEAT_ALL_NEW intro_step
    val expand = REPEAT_ALL_NEW split_and_subst ORELSE' K all_tac
    val unfold_abort = K (Local_Defs.unfold_tac ctxt [@{thm missing_pattern_match_def}])
    val unfold_defs = K (Local_Defs.unfold_tac ctxt defs)
    val finish = SOLVED' (resolve_tac ctxt (@{thm refl} :: intros))
  in
    expand THEN' unfold_abort THEN' unfold_defs THEN_ALL_NEW finish
  end;
| 543 | ||
(* Turn the defining equations [ths] of one function into equations whose
   right-hand sides use case expressions.  Returns NONE if [ths] is empty or
   if build_case_t produces no equations.  Steps:
   1. Instantiate every equation so that its head matches the head of the
      first equation (Thm.match / Thm.instantiate), then import all of them
      with Variable.import.
   2. Hide each right-hand side behind a fresh local definition named
      "rhs..." (Local_Defs.define).  The definition is abstracted over the
      free variables of that equation's arguments.
   3. Build the converted equation terms with build_case_t.
   4. Prove each term with tac, export the theorems back to [ctxt] and
      normalise them with Goal.norm_result. *)
fun to_case _ _ _ [] = NONE
  | to_case ctxt replace_ctr ctr_count ths =
    let
      val dest_eq_thm = HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of

      fun import_eqs [] lthy = ([], lthy)
        | import_eqs (first :: others) lthy =
          let
            val head_cterm =
              Thm.cterm_of lthy o Logic.mk_term o head_of o fst o dest_eq_thm
            val target = head_cterm first
            val other_heads = map head_cterm others
            val aligned =
              map2 (fn th => fn ct => Thm.instantiate (Thm.match (ct, target)) th)
                others other_heads
          in
            Variable.import true (first :: aligned) lthy |> apfst snd
          end

      val (imported, ctxt1) = import_eqs ths ctxt
      val head = head_of (fst (dest_eq_thm (hd imported)))
      val eqs = map (apfst (snd o strip_comb) o dest_eq_thm) imported

      (* Define [name] as the right-hand side, abstracted over the free variables
         of the arguments.  Returns the constant applied to those variables,
         together with its defining theorem. *)
      fun define_rhs ((args, rhs), name) lthy =
        let
          val vars = fold Term.add_frees args []
          val body = fold absfree vars rhs
        in
          (case Local_Defs.define [((Binding.name name, NoSyn), (Binding.empty_atts, body))] lthy of
            ([(c, (_, def_thm))], lthy') =>
              ((list_comb (c, map Free (rev vars)), def_thm), lthy')
          | _ => raise Match)
        end

      val names = Name.invent (Variable.names_of ctxt1) "rhs" (length eqs)
      val (hidden, ctxt2) = fold_map define_rhs (eqs ~~ names) ctxt1
      val (rhs_terms, rhs_defs) = split_list hidden
      val (goals, split_thms, ctxt3) =
        build_case_t replace_ctr ctr_count head (map fst eqs) rhs_terms ctxt2
      fun prove goal =
        Goal.prove ctxt3 [] [] goal (fn {context = goal_ctxt, ...} =>
          tac goal_ctxt {splits = split_thms, intros = ths, defs = rhs_defs} 1)
    in
      if null goals then NONE
      else
        SOME (map (Goal.norm_result ctxt)
          (Proof_Context.export ctxt3 ctxt (map prove goals)))
    end;
| 592 | ||
| 593 | end |