| 26170 |      1 | (*  Title:      HOL/Library/Array.thy
 | 
|  |      2 |     ID:         $Id$
 | 
|  |      3 |     Author:     John Matthews, Galois Connections; Alexander Krauss, Lukas Bulwahn & Florian Haftmann, TU Muenchen
 | 
|  |      4 | *)
 | 
|  |      5 | 
 | 
|  |      6 | header {* Monadic arrays *}
 | 
|  |      7 | 
 | 
|  |      8 | theory Array
 | 
| 26182 |      9 | imports Heap_Monad Code_Index
 | 
| 26170 |     10 | begin
 | 
|  |     11 | 
 | 
|  |     12 | subsection {* Primitives *}
 | 
|  |     13 | 
 | 
|  |        | (* new n x: Heap-monad action obtained by wrapping Heap.array n x with Heap_Monad.heap. *)
 | 
|  |        | (* [code del] withholds this equation from the code generator; array_make and the new' setup below supply code equations. *)
 | 
|  |     14 | definition
 | 
|  |     15 |   new :: "nat \<Rightarrow> 'a\<Colon>heap \<Rightarrow> 'a array Heap" where
 | 
|  |     16 |   [code del]: "new n x = Heap_Monad.heap (Heap.array n x)"
 | 
|  |     17 | 
 | 
|  |        | (* of_list xs: Heap-monad action obtained by wrapping Heap.array_of_list xs with Heap_Monad.heap. *)
 | 
|  |        | (* [code del]: the defining equation is not used for code generation; see array_of_list_make and of_list' below. *)
 | 
|  |     18 | definition
 | 
|  |     19 |   of_list :: "'a\<Colon>heap list \<Rightarrow> 'a array Heap" where
 | 
|  |     20 |   [code del]: "of_list xs = Heap_Monad.heap (Heap.array_of_list xs)"
 | 
|  |     21 | 
 | 
|  |        | (* length arr: returns Heap.length arr h for the current heap h and passes h on unchanged. *)
 | 
|  |     22 | definition
 | 
|  |     23 |   length :: "'a\<Colon>heap array \<Rightarrow> nat Heap" where
 | 
|  |     24 |   [code del]: "length arr = Heap_Monad.heap (\<lambda>h. (Heap.length arr h, h))"
 | 
|  |     25 | 
 | 
|  |        | (* nth a i: bounds-checked read of index i of array a. *)
 | 
|  |        | (* If i is below the length of a, yields get_array a h ! i and leaves the heap h unchanged; *)
 | 
|  |        | (* otherwise raises with the message ''array lookup: index out of range''. *)
 | 
|  |     26 | definition
 | 
|  |     27 |   nth :: "'a\<Colon>heap array \<Rightarrow> nat \<Rightarrow> 'a Heap"
 | 
|  |     28 | where
 | 
|  |     29 |   [code del]: "nth a i = (do len \<leftarrow> length a;
 | 
|  |     30 |                  (if i < len
 | 
|  |     31 |                      then Heap_Monad.heap (\<lambda>h. (get_array a h ! i, h))
 | 
|  |     32 |                      else raise (''array lookup: index out of range''))
 | 
|  |     33 |               done)"
 | 
|  |     34 | 
 | 
|  |     35 | -- {* FIXME adjustment for List theory *}
 | 
|  |        | (* Detaches the infix "!" notation from the constant currently named nth and re-attaches it to the *)
 | 
|  |        | (* abbreviation nth_list for List.nth, so that xs ! i is written against List.nth explicitly. *)
 | 
|  |        | (* NOTE(review): presumably needed because Array.nth, defined just above, now shadows the unqualified name nth -- confirm. *)
 | 
|  |     36 | no_syntax
 | 
|  |     37 |   nth  :: "'a list \<Rightarrow> nat \<Rightarrow> 'a" (infixl "!" 100)
 | 
|  |     38 | 
 | 
|  |     39 | abbreviation
 | 
|  |     40 |   nth_list :: "'a list \<Rightarrow> nat \<Rightarrow> 'a" (infixl "!" 100)
 | 
|  |     41 | where
 | 
|  |     42 |   "nth_list \<equiv> List.nth"
 | 
|  |     43 | 
 | 
|  |        | (* upd i x a: bounds-checked write of x at index i of array a. *)
 | 
|  |        | (* If i is below the length of a, the heap becomes Heap.upd a i x h and the array a itself is returned; *)
 | 
|  |        | (* otherwise raises with the message ''array update: index out of range''. *)
 | 
|  |     44 | definition
 | 
|  |     45 |   upd :: "nat \<Rightarrow> 'a \<Rightarrow> 'a\<Colon>heap array \<Rightarrow> 'a\<Colon>heap array Heap"
 | 
|  |     46 | where
 | 
|  |     47 |   [code del]: "upd i x a = (do len \<leftarrow> length a;
 | 
|  |     48 |                       (if i < len
 | 
| 26719 |     49 |                            then Heap_Monad.heap (\<lambda>h. (a, Heap.upd a i x h))
 | 
|  |     50 |                            else raise (''array update: index out of range''))
 | 
| 26170 |     51 |                    done)" 
 | 
|  |     52 | 
 | 
|  |        | (* upd already returns the array a, so sequencing return a after it changes nothing. *)
 | 
|  |        | (* Proof: Heap_eqI reduces the goal to equal executions on an arbitrary heap h; the result of running *)
 | 
|  |        | (* Array.length a on h is named (len, h') and the rest is discharged by auto after unfolding upd and bindM. *)
 | 
|  |     53 | lemma upd_return:
 | 
|  |     54 |   "upd i x a \<guillemotright> return a = upd i x a"
 | 
| 26719 |     55 | proof (rule Heap_eqI)
 | 
|  |     56 |   fix h
 | 
|  |     57 |   obtain len h' where "Heap_Monad.execute (Array.length a) h = (len, h')"
 | 
|  |     58 |     by (cases "Heap_Monad.execute (Array.length a) h")
 | 
|  |     59 |   then show "Heap_Monad.execute (upd i x a \<guillemotright> return a) h = Heap_Monad.execute (upd i x a) h"
 | 
|  |     60 |     by (auto simp add: upd_def bindM_def run_drop split: sum.split)
 | 
|  |     61 | qed
 | 
| 26170 |     62 | 
 | 
|  |     63 | 
 | 
|  |     64 | subsection {* Derivates *}
 | 
|  |     65 | 
 | 
|  |        | (* map_entry i f a: reads the entry x at index i with nth, then writes f x back at index i with upd. *)
 | 
|  |        | (* The result is whatever upd returns, i.e. the array a; an out-of-range i raises in nth. *)
 | 
|  |     66 | definition
 | 
|  |     67 |   map_entry :: "nat \<Rightarrow> ('a\<Colon>heap \<Rightarrow> 'a) \<Rightarrow> 'a array \<Rightarrow> 'a array Heap"
 | 
|  |     68 | where
 | 
|  |     69 |   "map_entry i f a = (do
 | 
|  |     70 |      x \<leftarrow> nth a i;
 | 
|  |     71 |      upd i (f x) a
 | 
|  |     72 |    done)"
 | 
|  |     73 | 
 | 
|  |        | (* swap i x a: stores x at index i of array a and returns the entry that was stored there before. *)
 | 
|  |        | (* The old entry y is read with nth before upd overwrites it; an out-of-range i raises in nth. *)
 | 
|  |        | (* Returns y (the previous entry), not the argument x, which the caller already has. *)
 | 
|  |     74 | definition
 | 
|  |     75 |   swap :: "nat \<Rightarrow> 'a \<Rightarrow> 'a\<Colon>heap array \<Rightarrow> 'a Heap"
 | 
|  |     76 | where
 | 
|  |     77 |   "swap i x a = (do
 | 
|  |     78 |      y \<leftarrow> nth a i;
 | 
|  |     79 |      upd i x a;
 | 
|  |     80 |      return y
 | 
|  |     81 |    done)"
 | 
|  |     82 | 
 | 
|  |        | (* make n f: builds an array via of_list from the list f 0, f 1, ..., f (n - 1), i.e. map f [0 ..< n]. *)
 | 
|  |     83 | definition
 | 
|  |     84 |   make :: "nat \<Rightarrow> (nat \<Rightarrow> 'a\<Colon>heap) \<Rightarrow> 'a array Heap"
 | 
|  |     85 | where
 | 
|  |     86 |   "make n f = of_list (map f [0 ..< n])"
 | 
|  |     87 | 
 | 
|  |        | (* freeze a: reads the length n of a, then runs nth a over every index in [0..<n] with mapM, *)
 | 
|  |        | (* giving the entries of a as a list inside the Heap monad. *)
 | 
|  |     88 | definition
 | 
|  |     89 |   freeze :: "'a\<Colon>heap array \<Rightarrow> 'a list Heap"
 | 
|  |     90 | where
 | 
|  |     91 |   "freeze a = (do
 | 
|  |     92 |      n \<leftarrow> length a;
 | 
|  |     93 |      mapM (nth a) [0..<n]
 | 
|  |     94 |    done)"
 | 
|  |     95 | 
 | 
|  |        | (* map f a: reads the length n of a and folds map_entry _ f over the indices [0..<n] with foldM, starting from a. *)
 | 
|  |        | (* Note: the lambda-bound n is the current index and shadows the length n bound on the line before. *)
 | 
|  |     96 | definition
 | 
|  |     97 |   map :: "('a\<Colon>heap \<Rightarrow> 'a) \<Rightarrow> 'a array \<Rightarrow> 'a array Heap"
 | 
|  |     98 | where
 | 
|  |     99 |   "map f a = (do
 | 
|  |    100 |      n \<leftarrow> length a;
 | 
|  |    101 |      foldM (\<lambda>n. map_entry n f) [0..<n] a
 | 
|  |    102 |    done)"
 | 
|  |    103 | 
 | 
|  |    104 | hide (open) const new map -- {* avoid clashes with some popular names *}
 | 
|  |    105 | 
 | 
|  |    106 | 
 | 
|  |    107 | subsection {* Properties *}
 | 
|  |    108 | 
 | 
|  |        | (* Code equation: Array.new n x equals make applied to n and the constant function returning x. *)
 | 
|  |        | (* Proved by induction on n, unfolding make, new and of_list and using the replicate lemmas listed below. *)
 | 
|  |    109 | lemma array_make [code func]:
 | 
|  |    110 |   "Array.new n x = make n (\<lambda>_. x)"
 | 
|  |    111 |   by (induct n) (simp_all add: make_def new_def Heap_Monad.heap_def
 | 
|  |    112 |     monad_simp array_of_list_replicate [symmetric]
 | 
|  |    113 |     map_replicate_trivial replicate_append_same
 | 
|  |    114 |     of_list_def)
 | 
|  |    115 | 
 | 
|  |        | (* Code equation: of_list xs equals make over the length of xs with the indexing function of xs. *)
 | 
|  |        | (* Proved by unfolding make_def and rewriting with map_nth, after which both sides coincide. *)
 | 
|  |    116 | lemma array_of_list_make [code func]:
 | 
|  |    117 |   "of_list xs = make (List.length xs) (\<lambda>n. xs ! n)"
 | 
|  |    118 |   unfolding make_def map_nth ..
 | 
|  |    119 | 
 | 
| 26182 |    120 | 
 | 
|  |    121 | subsection {* Code generator setup *}
 | 
|  |    122 | 
 | 
|  |    123 | subsubsection {* Logical intermediate layer *}
 | 
|  |    124 | 
 | 
|  |        | (* new': variant of Array.new whose size argument is first converted with nat_of_index. *)
 | 
|  |        | (* The unnamed code equation expresses Array.new through new' by converting the size with index_of_nat. *)
 | 
|  |    125 | definition new' where
 | 
|  |    126 |   [code del]: "new' = Array.new o nat_of_index"
 | 
|  |    127 | hide (open) const new'
 | 
|  |    128 | lemma [code func]:
 | 
|  |    129 |   "Array.new = Array.new' o index_of_nat"
 | 
|  |    130 |   by (simp add: new'_def o_def)
 | 
|  |    131 | 
 | 
|  |        | (* of_list' i xs: Array.of_list applied to the first nat_of_index i elements of xs. *)
 | 
|  |        | (* The unnamed code equation calls it with i set to the full length of xs, so take keeps the whole list. *)
 | 
|  |    132 | definition of_list' where
 | 
|  |    133 |   [code del]: "of_list' i xs = Array.of_list (take (nat_of_index i) xs)"
 | 
|  |    134 | hide (open) const of_list'
 | 
|  |    135 | lemma [code func]:
 | 
|  |    136 |   "Array.of_list xs = Array.of_list' (index_of_nat (List.length xs)) xs"
 | 
|  |    137 |   by (simp add: of_list'_def)
 | 
|  |    138 | 
 | 
|  |        | (* make' i f: Array.make with the size converted by nat_of_index and f precomposed with index_of_nat. *)
 | 
|  |        | (* The unnamed code equation expresses Array.make through make' using the opposite conversions. *)
 | 
|  |    139 | definition make' where
 | 
|  |    140 |   [code del]: "make' i f = Array.make (nat_of_index i) (f o index_of_nat)"
 | 
|  |    141 | hide (open) const make'
 | 
|  |    142 | lemma [code func]:
 | 
|  |    143 |   "Array.make n f = Array.make' (index_of_nat n) (f o nat_of_index)"
 | 
|  |    144 |   by (simp add: make'_def o_def)
 | 
|  |    145 | 
 | 
|  |        | (* length': Array.length combined with liftM index_of_nat, so the length is delivered in the index type. *)
 | 
|  |        | (* The unnamed code equation recovers Array.length from length' by applying liftM nat_of_index. *)
 | 
|  |        | (* The three simp steps of the proof run in the order given; keep them in sequence. *)
 | 
|  |    146 | definition length' where
 | 
|  |    147 |   [code del]: "length' = Array.length \<guillemotright>== liftM index_of_nat"
 | 
|  |    148 | hide (open) const length'
 | 
|  |    149 | lemma [code func]:
 | 
|  |    150 |   "Array.length = Array.length' \<guillemotright>== liftM nat_of_index"
 | 
|  |    151 |   by (simp add: length'_def monad_simp',
 | 
|  |    152 |     simp add: liftM_def comp_def monad_simp,
 | 
|  |    153 |     simp add: monad_simp')
 | 
|  |    154 | 
 | 
|  |        | (* nth' a: Array.nth a with the position argument first converted by nat_of_index. *)
 | 
|  |        | (* The unnamed code equation expresses Array.nth a n as nth' a applied to index_of_nat n. *)
 | 
|  |    155 | definition nth' where
 | 
|  |    156 |   [code del]: "nth' a = Array.nth a o nat_of_index"
 | 
|  |    157 | hide (open) const nth'
 | 
|  |    158 | lemma [code func]:
 | 
|  |    159 |   "Array.nth a n = Array.nth' a (index_of_nat n)"
 | 
|  |    160 |   by (simp add: nth'_def)
 | 
|  |    161 | 
 | 
|  |        | (* upd' a i x: runs Array.upd (nat_of_index i) x a and then returns unit instead of the array. *)
 | 
|  |        | (* Argument order differs from Array.upd: the array comes first, then the position, then the value. *)
 | 
|  |        | (* The unnamed code equation restores the array result by sequencing return a; its proof uses upd_return. *)
 | 
|  |    162 | definition upd' where
 | 
|  |    163 |   [code del]: "upd' a i x = Array.upd (nat_of_index i) x a \<guillemotright> return ()"
 | 
|  |    164 | hide (open) const upd'
 | 
|  |    165 | lemma [code func]:
 | 
|  |    166 |   "Array.upd i x a = Array.upd' a (index_of_nat i) x \<guillemotright> return a"
 | 
| 26743 |    167 |   by (simp add: upd'_def monad_simp upd_return)
 | 
| 26182 |    168 | 
 | 
|  |    169 | 
 | 
|  |    170 | subsubsection {* SML *}
 | 
|  |    171 | 
 | 
|  |        | (* SML serialisation: the array type is printed as SML's array type, and each operation as a *)
 | 
|  |        | (* unit-taking closure fn () => ... around the matching function of the SML Array structure. *)
 | 
|  |        | (* The bare constructor Array is printed as an expression raising Fail "bare Array". *)
 | 
|  |        | (* Note: the unprimed Array.of_list is mapped here, whereas the other operations use the primed variants. *)
 | 
|  |    172 | code_type array (SML "_/ array")
 | 
|  |    173 | code_const Array (SML "raise/ (Fail/ \"bare Array\")")
 | 
| 26752 |    174 | code_const Array.new' (SML "(fn/ ()/ =>/ Array.array/ ((_),/ (_)))")
 | 
|  |    175 | code_const Array.of_list (SML "(fn/ ()/ =>/ Array.fromList/ _)")
 | 
|  |    176 | code_const Array.make' (SML "(fn/ ()/ =>/ Array.tabulate/ ((_),/ (_)))")
 | 
|  |    177 | code_const Array.length' (SML "(fn/ ()/ =>/ Array.length/ _)")
 | 
|  |    178 | code_const Array.nth' (SML "(fn/ ()/ =>/ Array.sub/ ((_),/ (_)))")
 | 
|  |    179 | code_const Array.upd' (SML "(fn/ ()/ =>/ Array.update/ ((_),/ (_),/ (_)))")
 | 
| 26182 |    180 | 
 | 
|  |        | (* Reserve the identifier Array so generated SML names cannot collide with the Array structure. *)
 | 
|  |    181 | code_reserved SML Array
 | 
|  |    182 | 
 | 
|  |    183 | 
 | 
|  |    184 | subsubsection {* OCaml *}
 | 
|  |    185 | 
 | 
|  |        | (* OCaml serialisation: the array type is printed as OCaml's array type, and each operation as a *)
 | 
|  |        | (* unit-taking closure around the matching function of the OCaml Array module. *)
 | 
|  |        | (* The closures are spelled fun () -> ... because OCaml has no fn ... => ... form (that is SML syntax). *)
 | 
|  |        | (* The bare constructor Array is printed as failwith "bare Array". *)
 | 
|  |    186 | code_type array (OCaml "_/ array")
 | 
|  |    187 | code_const Array (OCaml "failwith/ \"bare Array\"")
 | 
| 26752 |    188 | code_const Array.new' (OCaml "(fun/ ()/ ->/ Array.make/ _/ _)")
 | 
|  |    189 | code_const Array.of_list (OCaml "(fun/ ()/ ->/ Array.of'_list/ _)")
 | 
|  |    190 | code_const Array.make' (OCaml "(fun/ ()/ ->/ Array.init/ _/ _)")
 | 
|  |    191 | code_const Array.length' (OCaml "(fun/ ()/ ->/ Array.length/ _)")
 | 
|  |    192 | code_const Array.nth' (OCaml "(fun/ ()/ ->/ Array.get/ _/ _)")
 | 
|  |    193 | code_const Array.upd' (OCaml "(fun/ ()/ ->/ Array.set/ _/ _/ _)")
 | 
| 26182 |    194 | 
 | 
|  |        | (* Reserve the identifier Array so generated OCaml names cannot collide with the Array module. *)
 | 
|  |    195 | code_reserved OCaml Array
 | 
|  |    196 | 
 | 
|  |    197 | 
 | 
|  |    198 | subsubsection {* Haskell *}
 | 
|  |    199 | 
 | 
|  |        | (* Haskell serialisation: the array type is printed as STArray, and the primed operations are printed *)
 | 
|  |        | (* directly as newArray, newListArray, length, readArray and writeArray; Array becomes error "bare Array". *)
 | 
|  |        | (* NOTE(review): newArray and newListArray normally take inclusive bounds, so (0, n) may describe n + 1 cells -- confirm. *)
 | 
|  |        | (* NOTE(review): Prelude length is a list function, not an STArray operation -- confirm the length' mapping type-checks. *)
 | 
|  |        | (* NOTE(review): no mapping for Array.make' appears in this section, unlike SML and OCaml -- confirm this is intended. *)
 | 
|  |    200 | code_type array (Haskell "STArray '_s _")
 | 
|  |    201 | code_const Array (Haskell "error/ \"bare Array\"")
 | 
|  |    202 | code_const Array.new' (Haskell "newArray/ (0,/ _)")
 | 
|  |    203 | code_const Array.of_list' (Haskell "newListArray/ (0,/ _)")
 | 
|  |    204 | code_const Array.length' (Haskell "length")
 | 
|  |    205 | code_const Array.nth' (Haskell "readArray")
 | 
|  |    206 | code_const Array.upd' (Haskell "writeArray")
 | 
|  |    207 | 
 | 
| 26170 |    208 | end
 |