blob: 122b3fcebf3fedd29ece099d1a64b2ef771e29f6 [file] [log] [blame]
Yi Kong878f9942023-12-13 12:55:04 +09001"""Provides shared memory for direct access across processes.
2
3The API of this package is currently provisional. Refer to the
4documentation for details.
5"""
6
7
8__all__ = [ 'SharedMemory', 'ShareableList' ]
9
10
11from functools import partial
12import mmap
13import os
14import errno
15import struct
16import secrets
17import types
18
# Everything except Windows uses POSIX shared memory (shm_open & co.);
# Windows uses named file mappings via _winapi.
_USE_POSIX = os.name != "nt"
if _USE_POSIX:
    import _posixshmem
else:
    import _winapi


# Flags for exclusive creation: fail if the name already exists.
_O_CREX = os.O_CREAT | os.O_EXCL

# FreeBSD (and perhaps other BSDs) limit names to 14 characters.
_SHM_SAFE_NAME_LENGTH = 14

# Shared memory block name prefix; POSIX names must begin with a slash.
_SHM_NAME_PREFIX = '/psm_' if _USE_POSIX else 'wnsm_'
37
38
def _make_filename():
    "Create a random filename for the shared memory object."
    # token_hex() emits two hex digits per byte, so spend as many random
    # bytes as fit in the room left after the prefix.
    token_bytes = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2
    assert token_bytes >= 2, '_SHM_NAME_PREFIX too long'
    candidate = f"{_SHM_NAME_PREFIX}{secrets.token_hex(token_bytes)}"
    assert len(candidate) <= _SHM_SAFE_NAME_LENGTH
    return candidate
47
48
49class SharedMemory:
50 """Creates a new shared memory block or attaches to an existing
51 shared memory block.
52
53 Every shared memory block is assigned a unique name. This enables
54 one process to create a shared memory block with a particular name
55 so that a different process can attach to that same shared memory
56 block using that same name.
57
58 As a resource for sharing data across processes, shared memory blocks
59 may outlive the original process that created them. When one process
60 no longer needs access to a shared memory block that might still be
61 needed by other processes, the close() method should be called.
62 When a shared memory block is no longer needed by any process, the
63 unlink() method should be called to ensure proper cleanup."""
64
65 # Defaults; enables close() and unlink() to run without errors.
66 _name = None
67 _fd = -1
68 _mmap = None
69 _buf = None
70 _flags = os.O_RDWR
71 _mode = 0o600
72 _prepend_leading_slash = True if _USE_POSIX else False
73
74 def __init__(self, name=None, create=False, size=0):
75 if not size >= 0:
76 raise ValueError("'size' must be a positive integer")
77 if create:
78 self._flags = _O_CREX | os.O_RDWR
79 if size == 0:
80 raise ValueError("'size' must be a positive number different from zero")
81 if name is None and not self._flags & os.O_EXCL:
82 raise ValueError("'name' can only be None if create=True")
83
84 if _USE_POSIX:
85
86 # POSIX Shared Memory
87
88 if name is None:
89 while True:
90 name = _make_filename()
91 try:
92 self._fd = _posixshmem.shm_open(
93 name,
94 self._flags,
95 mode=self._mode
96 )
97 except FileExistsError:
98 continue
99 self._name = name
100 break
101 else:
102 name = "/" + name if self._prepend_leading_slash else name
103 self._fd = _posixshmem.shm_open(
104 name,
105 self._flags,
106 mode=self._mode
107 )
108 self._name = name
109 try:
110 if create and size:
111 os.ftruncate(self._fd, size)
112 stats = os.fstat(self._fd)
113 size = stats.st_size
114 self._mmap = mmap.mmap(self._fd, size)
115 except OSError:
116 self.unlink()
117 raise
118
119 from .resource_tracker import register
120 register(self._name, "shared_memory")
121
122 else:
123
124 # Windows Named Shared Memory
125
126 if create:
127 while True:
128 temp_name = _make_filename() if name is None else name
129 # Create and reserve shared memory block with this name
130 # until it can be attached to by mmap.
131 h_map = _winapi.CreateFileMapping(
132 _winapi.INVALID_HANDLE_VALUE,
133 _winapi.NULL,
134 _winapi.PAGE_READWRITE,
135 (size >> 32) & 0xFFFFFFFF,
136 size & 0xFFFFFFFF,
137 temp_name
138 )
139 try:
140 last_error_code = _winapi.GetLastError()
141 if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
142 if name is not None:
143 raise FileExistsError(
144 errno.EEXIST,
145 os.strerror(errno.EEXIST),
146 name,
147 _winapi.ERROR_ALREADY_EXISTS
148 )
149 else:
150 continue
151 self._mmap = mmap.mmap(-1, size, tagname=temp_name)
152 finally:
153 _winapi.CloseHandle(h_map)
154 self._name = temp_name
155 break
156
157 else:
158 self._name = name
159 # Dynamically determine the existing named shared memory
160 # block's size which is likely a multiple of mmap.PAGESIZE.
161 h_map = _winapi.OpenFileMapping(
162 _winapi.FILE_MAP_READ,
163 False,
164 name
165 )
166 try:
167 p_buf = _winapi.MapViewOfFile(
168 h_map,
169 _winapi.FILE_MAP_READ,
170 0,
171 0,
172 0
173 )
174 finally:
175 _winapi.CloseHandle(h_map)
176 size = _winapi.VirtualQuerySize(p_buf)
177 self._mmap = mmap.mmap(-1, size, tagname=name)
178
179 self._size = size
180 self._buf = memoryview(self._mmap)
181
182 def __del__(self):
183 try:
184 self.close()
185 except OSError:
186 pass
187
188 def __reduce__(self):
189 return (
190 self.__class__,
191 (
192 self.name,
193 False,
194 self.size,
195 ),
196 )
197
198 def __repr__(self):
199 return f'{self.__class__.__name__}({self.name!r}, size={self.size})'
200
201 @property
202 def buf(self):
203 "A memoryview of contents of the shared memory block."
204 return self._buf
205
206 @property
207 def name(self):
208 "Unique name that identifies the shared memory block."
209 reported_name = self._name
210 if _USE_POSIX and self._prepend_leading_slash:
211 if self._name.startswith("/"):
212 reported_name = self._name[1:]
213 return reported_name
214
215 @property
216 def size(self):
217 "Size in bytes."
218 return self._size
219
220 def close(self):
221 """Closes access to the shared memory from this instance but does
222 not destroy the shared memory block."""
223 if self._buf is not None:
224 self._buf.release()
225 self._buf = None
226 if self._mmap is not None:
227 self._mmap.close()
228 self._mmap = None
229 if _USE_POSIX and self._fd >= 0:
230 os.close(self._fd)
231 self._fd = -1
232
233 def unlink(self):
234 """Requests that the underlying shared memory block be destroyed.
235
236 In order to ensure proper cleanup of resources, unlink should be
237 called once (and only once) across all processes which have access
238 to the shared memory block."""
239 if _USE_POSIX and self._name:
240 from .resource_tracker import unregister
241 _posixshmem.shm_unlink(self._name)
242 unregister(self._name, "shared_memory")
243
244
245_encoding = "utf8"
246
247class ShareableList:
248 """Pattern for a mutable list-like object shareable via a shared
249 memory block. It differs from the built-in list type in that these
250 lists can not change their overall length (i.e. no append, insert,
251 etc.)
252
253 Because values are packed into a memoryview as bytes, the struct
254 packing format for any storable value must require no more than 8
255 characters to describe its format."""
256
257 # The shared memory area is organized as follows:
258 # - 8 bytes: number of items (N) as a 64-bit integer
259 # - (N + 1) * 8 bytes: offsets of each element from the start of the
260 # data area
261 # - K bytes: the data area storing item values (with encoding and size
262 # depending on their respective types)
263 # - N * 8 bytes: `struct` format string for each element
264 # - N bytes: index into _back_transforms_mapping for each element
265 # (for reconstructing the corresponding Python value)
266 _types_mapping = {
267 int: "q",
268 float: "d",
269 bool: "xxxxxxx?",
270 str: "%ds",
271 bytes: "%ds",
272 None.__class__: "xxxxxx?x",
273 }
274 _alignment = 8
275 _back_transforms_mapping = {
276 0: lambda value: value, # int, float, bool
277 1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str
278 2: lambda value: value.rstrip(b'\x00'), # bytes
279 3: lambda _value: None, # None
280 }
281
282 @staticmethod
283 def _extract_recreation_code(value):
284 """Used in concert with _back_transforms_mapping to convert values
285 into the appropriate Python objects when retrieving them from
286 the list as well as when storing them."""
287 if not isinstance(value, (str, bytes, None.__class__)):
288 return 0
289 elif isinstance(value, str):
290 return 1
291 elif isinstance(value, bytes):
292 return 2
293 else:
294 return 3 # NoneType
295
296 def __init__(self, sequence=None, *, name=None):
297 if name is None or sequence is not None:
298 sequence = sequence or ()
299 _formats = [
300 self._types_mapping[type(item)]
301 if not isinstance(item, (str, bytes))
302 else self._types_mapping[type(item)] % (
303 self._alignment * (len(item) // self._alignment + 1),
304 )
305 for item in sequence
306 ]
307 self._list_len = len(_formats)
308 assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
309 offset = 0
310 # The offsets of each list element into the shared memory's
311 # data area (0 meaning the start of the data area, not the start
312 # of the shared memory area).
313 self._allocated_offsets = [0]
314 for fmt in _formats:
315 offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
316 self._allocated_offsets.append(offset)
317 _recreation_codes = [
318 self._extract_recreation_code(item) for item in sequence
319 ]
320 requested_size = struct.calcsize(
321 "q" + self._format_size_metainfo +
322 "".join(_formats) +
323 self._format_packing_metainfo +
324 self._format_back_transform_codes
325 )
326
327 self.shm = SharedMemory(name, create=True, size=requested_size)
328 else:
329 self.shm = SharedMemory(name)
330
331 if sequence is not None:
332 _enc = _encoding
333 struct.pack_into(
334 "q" + self._format_size_metainfo,
335 self.shm.buf,
336 0,
337 self._list_len,
338 *(self._allocated_offsets)
339 )
340 struct.pack_into(
341 "".join(_formats),
342 self.shm.buf,
343 self._offset_data_start,
344 *(v.encode(_enc) if isinstance(v, str) else v for v in sequence)
345 )
346 struct.pack_into(
347 self._format_packing_metainfo,
348 self.shm.buf,
349 self._offset_packing_formats,
350 *(v.encode(_enc) for v in _formats)
351 )
352 struct.pack_into(
353 self._format_back_transform_codes,
354 self.shm.buf,
355 self._offset_back_transform_codes,
356 *(_recreation_codes)
357 )
358
359 else:
360 self._list_len = len(self) # Obtains size from offset 0 in buffer.
361 self._allocated_offsets = list(
362 struct.unpack_from(
363 self._format_size_metainfo,
364 self.shm.buf,
365 1 * 8
366 )
367 )
368
369 def _get_packing_format(self, position):
370 "Gets the packing format for a single value stored in the list."
371 position = position if position >= 0 else position + self._list_len
372 if (position >= self._list_len) or (self._list_len < 0):
373 raise IndexError("Requested position out of range.")
374
375 v = struct.unpack_from(
376 "8s",
377 self.shm.buf,
378 self._offset_packing_formats + position * 8
379 )[0]
380 fmt = v.rstrip(b'\x00')
381 fmt_as_str = fmt.decode(_encoding)
382
383 return fmt_as_str
384
385 def _get_back_transform(self, position):
386 "Gets the back transformation function for a single value."
387
388 if (position >= self._list_len) or (self._list_len < 0):
389 raise IndexError("Requested position out of range.")
390
391 transform_code = struct.unpack_from(
392 "b",
393 self.shm.buf,
394 self._offset_back_transform_codes + position
395 )[0]
396 transform_function = self._back_transforms_mapping[transform_code]
397
398 return transform_function
399
400 def _set_packing_format_and_transform(self, position, fmt_as_str, value):
401 """Sets the packing format and back transformation code for a
402 single value in the list at the specified position."""
403
404 if (position >= self._list_len) or (self._list_len < 0):
405 raise IndexError("Requested position out of range.")
406
407 struct.pack_into(
408 "8s",
409 self.shm.buf,
410 self._offset_packing_formats + position * 8,
411 fmt_as_str.encode(_encoding)
412 )
413
414 transform_code = self._extract_recreation_code(value)
415 struct.pack_into(
416 "b",
417 self.shm.buf,
418 self._offset_back_transform_codes + position,
419 transform_code
420 )
421
422 def __getitem__(self, position):
423 position = position if position >= 0 else position + self._list_len
424 try:
425 offset = self._offset_data_start + self._allocated_offsets[position]
426 (v,) = struct.unpack_from(
427 self._get_packing_format(position),
428 self.shm.buf,
429 offset
430 )
431 except IndexError:
432 raise IndexError("index out of range")
433
434 back_transform = self._get_back_transform(position)
435 v = back_transform(v)
436
437 return v
438
439 def __setitem__(self, position, value):
440 position = position if position >= 0 else position + self._list_len
441 try:
442 item_offset = self._allocated_offsets[position]
443 offset = self._offset_data_start + item_offset
444 current_format = self._get_packing_format(position)
445 except IndexError:
446 raise IndexError("assignment index out of range")
447
448 if not isinstance(value, (str, bytes)):
449 new_format = self._types_mapping[type(value)]
450 encoded_value = value
451 else:
452 allocated_length = self._allocated_offsets[position + 1] - item_offset
453
454 encoded_value = (value.encode(_encoding)
455 if isinstance(value, str) else value)
456 if len(encoded_value) > allocated_length:
457 raise ValueError("bytes/str item exceeds available storage")
458 if current_format[-1] == "s":
459 new_format = current_format
460 else:
461 new_format = self._types_mapping[str] % (
462 allocated_length,
463 )
464
465 self._set_packing_format_and_transform(
466 position,
467 new_format,
468 value
469 )
470 struct.pack_into(new_format, self.shm.buf, offset, encoded_value)
471
472 def __reduce__(self):
473 return partial(self.__class__, name=self.shm.name), ()
474
475 def __len__(self):
476 return struct.unpack_from("q", self.shm.buf, 0)[0]
477
478 def __repr__(self):
479 return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'
480
481 @property
482 def format(self):
483 "The struct packing format used by all currently stored items."
484 return "".join(
485 self._get_packing_format(i) for i in range(self._list_len)
486 )
487
488 @property
489 def _format_size_metainfo(self):
490 "The struct packing format used for the items' storage offsets."
491 return "q" * (self._list_len + 1)
492
493 @property
494 def _format_packing_metainfo(self):
495 "The struct packing format used for the items' packing formats."
496 return "8s" * self._list_len
497
498 @property
499 def _format_back_transform_codes(self):
500 "The struct packing format used for the items' back transforms."
501 return "b" * self._list_len
502
503 @property
504 def _offset_data_start(self):
505 # - 8 bytes for the list length
506 # - (N + 1) * 8 bytes for the element offsets
507 return (self._list_len + 2) * 8
508
509 @property
510 def _offset_packing_formats(self):
511 return self._offset_data_start + self._allocated_offsets[-1]
512
513 @property
514 def _offset_back_transform_codes(self):
515 return self._offset_packing_formats + self._list_len * 8
516
517 def count(self, value):
518 "L.count(value) -> integer -- return number of occurrences of value."
519
520 return sum(value == entry for entry in self)
521
522 def index(self, value):
523 """L.index(value) -> integer -- return first index of value.
524 Raises ValueError if the value is not present."""
525
526 for position, entry in enumerate(self):
527 if value == entry:
528 return position
529 else:
530 raise ValueError(f"{value!r} not in this container")
531
532 __class_getitem__ = classmethod(types.GenericAlias)