Yi Kong | 878f994 | 2023-12-13 12:55:04 +0900 | [diff] [blame^] | 1 | """Provides shared memory for direct access across processes. |
| 2 | |
| 3 | The API of this package is currently provisional. Refer to the |
| 4 | documentation for details. |
| 5 | """ |
| 6 | |
| 7 | |
| 8 | __all__ = [ 'SharedMemory', 'ShareableList' ] |
| 9 | |
| 10 | |
| 11 | from functools import partial |
| 12 | import mmap |
| 13 | import os |
| 14 | import errno |
| 15 | import struct |
| 16 | import secrets |
| 17 | import types |
| 18 | |
# Select the platform backend once at import time: Windows uses named file
# mappings through _winapi, everything else uses POSIX shared memory through
# _posixshmem.
_USE_POSIX = os.name != "nt"
if _USE_POSIX:
    import _posixshmem
else:
    import _winapi


# Open flags requesting exclusive creation (fail if the name already exists).
_O_CREX = os.O_CREAT | os.O_EXCL

# FreeBSD (and perhaps other BSDs) limit names to 14 characters.
_SHM_SAFE_NAME_LENGTH = 14

# Prefix of automatically generated shared memory block names.
_SHM_NAME_PREFIX = '/psm_' if _USE_POSIX else 'wnsm_'
| 38 | |
def _make_filename():
    """Create a random filename for the shared memory object.

    The result is the platform prefix followed by random hex digits, kept
    within _SHM_SAFE_NAME_LENGTH characters in total.
    """
    # Each random byte becomes two hex digits, hence the halving.
    random_byte_count = (_SHM_SAFE_NAME_LENGTH - len(_SHM_NAME_PREFIX)) // 2
    assert random_byte_count >= 2, '_SHM_NAME_PREFIX too long'
    candidate = f"{_SHM_NAME_PREFIX}{secrets.token_hex(random_byte_count)}"
    assert len(candidate) <= _SHM_SAFE_NAME_LENGTH
    return candidate
| 47 | |
| 48 | |
class SharedMemory:
    """Creates a new shared memory block or attaches to an existing
    shared memory block.

    Every shared memory block is assigned a unique name. This enables
    one process to create a shared memory block with a particular name
    so that a different process can attach to that same shared memory
    block using that same name.

    As a resource for sharing data across processes, shared memory blocks
    may outlive the original process that created them. When one process
    no longer needs access to a shared memory block that might still be
    needed by other processes, the close() method should be called.
    When a shared memory block is no longer needed by any process, the
    unlink() method should be called to ensure proper cleanup.

    Constructor arguments:
      name   -- name of the block. May be None only when create=True, in
                which case random names are tried until an unused one is
                found.
      create -- True to create a new block (FileExistsError if the given
                name is already taken), False to attach to an existing one.
      size   -- requested size in bytes; must be >= 0, and non-zero when
                create=True. On POSIX, and when attaching on Windows, the
                size recorded on the instance is read back from the
                platform and may differ from the request.

    Raises ValueError for an invalid size/name/create combination."""

    # Defaults; enables close() and unlink() to run without errors.
    _name = None
    _fd = -1
    _mmap = None
    _buf = None
    _flags = os.O_RDWR
    _mode = 0o600
    _prepend_leading_slash = True if _USE_POSIX else False

    def __init__(self, name=None, create=False, size=0):
        if not size >= 0:
            raise ValueError("'size' must be a positive integer")
        if create:
            # Exclusive creation: opening must fail if the name exists.
            self._flags = _O_CREX | os.O_RDWR
            if size == 0:
                raise ValueError("'size' must be a positive number different from zero")
        # O_EXCL is only set on the create path above, so this rejects
        # name=None when attaching.
        if name is None and not self._flags & os.O_EXCL:
            raise ValueError("'name' can only be None if create=True")

        if _USE_POSIX:

            # POSIX Shared Memory

            if name is None:
                # Keep drawing random names until one is not already in use.
                while True:
                    name = _make_filename()
                    try:
                        self._fd = _posixshmem.shm_open(
                            name,
                            self._flags,
                            mode=self._mode
                        )
                    except FileExistsError:
                        continue
                    self._name = name
                    break
            else:
                # The stored name carries the leading slash; the public
                # `name` property strips it again.
                name = "/" + name if self._prepend_leading_slash else name
                self._fd = _posixshmem.shm_open(
                    name,
                    self._flags,
                    mode=self._mode
                )
                self._name = name
            try:
                if create and size:
                    os.ftruncate(self._fd, size)
                # The size is always taken from the descriptor, so when
                # attaching the `size` argument is not used.
                stats = os.fstat(self._fd)
                size = stats.st_size
                self._mmap = mmap.mmap(self._fd, size)
            except OSError:
                # NOTE(review): this unlinks the block on any failure here,
                # including when attaching (create=False) -- confirm that
                # is intended.
                self.unlink()
                raise

            from .resource_tracker import register
            register(self._name, "shared_memory")

        else:

            # Windows Named Shared Memory

            if create:
                while True:
                    temp_name = _make_filename() if name is None else name
                    # Create and reserve shared memory block with this name
                    # until it can be attached to by mmap.
                    h_map = _winapi.CreateFileMapping(
                        _winapi.INVALID_HANDLE_VALUE,
                        _winapi.NULL,
                        _winapi.PAGE_READWRITE,
                        (size >> 32) & 0xFFFFFFFF,
                        size & 0xFFFFFFFF,
                        temp_name
                    )
                    try:
                        # Read the error code straight after the call that
                        # sets it, before any other API call.
                        last_error_code = _winapi.GetLastError()
                        if last_error_code == _winapi.ERROR_ALREADY_EXISTS:
                            if name is not None:
                                raise FileExistsError(
                                    errno.EEXIST,
                                    os.strerror(errno.EEXIST),
                                    name,
                                    _winapi.ERROR_ALREADY_EXISTS
                                )
                            else:
                                # Random name collided; the finally clause
                                # still closes h_map before retrying.
                                continue
                        self._mmap = mmap.mmap(-1, size, tagname=temp_name)
                    finally:
                        _winapi.CloseHandle(h_map)
                    self._name = temp_name
                    break

            else:
                self._name = name
                # Dynamically determine the existing named shared memory
                # block's size which is likely a multiple of mmap.PAGESIZE.
                h_map = _winapi.OpenFileMapping(
                    _winapi.FILE_MAP_READ,
                    False,
                    name
                )
                try:
                    p_buf = _winapi.MapViewOfFile(
                        h_map,
                        _winapi.FILE_MAP_READ,
                        0,
                        0,
                        0
                    )
                finally:
                    _winapi.CloseHandle(h_map)
                # NOTE(review): the view behind p_buf is not explicitly
                # unmapped anywhere in this block -- confirm whether it
                # stays mapped for the life of the process.
                size = _winapi.VirtualQuerySize(p_buf)
                self._mmap = mmap.mmap(-1, size, tagname=name)

        self._size = size
        self._buf = memoryview(self._mmap)

    def __del__(self):
        # Best-effort cleanup: an OSError from close() is deliberately
        # ignored.
        try:
            self.close()
        except OSError:
            pass

    def __reduce__(self):
        # Reconstructs as SharedMemory(name, False, size), i.e. the copy
        # attaches to the existing block instead of creating a new one.
        return (
            self.__class__,
            (
                self.name,
                False,
                self.size,
            ),
        )

    def __repr__(self):
        return f'{self.__class__.__name__}({self.name!r}, size={self.size})'

    @property
    def buf(self):
        "A memoryview of contents of the shared memory block."
        return self._buf

    @property
    def name(self):
        "Unique name that identifies the shared memory block."
        reported_name = self._name
        # Hide the leading slash that the POSIX path adds to the stored name.
        if _USE_POSIX and self._prepend_leading_slash:
            if self._name.startswith("/"):
                reported_name = self._name[1:]
        return reported_name

    @property
    def size(self):
        "Size in bytes."
        return self._size

    def close(self):
        """Closes access to the shared memory from this instance but does
        not destroy the shared memory block."""
        # Release in reverse order of acquisition: memoryview, mmap, then
        # (POSIX only) the file descriptor. Each step resets its attribute
        # so a repeated call skips it.
        if self._buf is not None:
            self._buf.release()
            self._buf = None
        if self._mmap is not None:
            self._mmap.close()
            self._mmap = None
        if _USE_POSIX and self._fd >= 0:
            os.close(self._fd)
            self._fd = -1

    def unlink(self):
        """Requests that the underlying shared memory block be destroyed.

        In order to ensure proper cleanup of resources, unlink should be
        called once (and only once) across all processes which have access
        to the shared memory block."""
        # Only the POSIX backend does anything here; on Windows the
        # condition is false and the call is a no-op.
        if _USE_POSIX and self._name:
            from .resource_tracker import unregister
            _posixshmem.shm_unlink(self._name)
            unregister(self._name, "shared_memory")
| 244 | |
_encoding = "utf8"

class ShareableList:
    """Pattern for a mutable list-like object shareable via a shared
    memory block. It differs from the built-in list type in that these
    lists can not change their overall length (i.e. no append, insert,
    etc.)

    Because values are packed into a memoryview as bytes, the struct
    packing format for any storable value must require no more than 8
    characters to describe its format.

    Supported item types are int, float, bool, str, bytes and None.
    str and bytes values are stored padded with nul bytes, and trailing
    nul bytes are stripped again when a value is read back.

    ShareableList(sequence) creates a new list in a new shared memory
    block; ShareableList(name=name) attaches to an existing one."""

    # The shared memory area is organized as follows:
    # - 8 bytes: number of items (N) as a 64-bit integer
    # - (N + 1) * 8 bytes: offsets of each element from the start of the
    #                      data area
    # - K bytes: the data area storing item values (with encoding and size
    #            depending on their respective types)
    # - N * 8 bytes: `struct` format string for each element
    # - N bytes: index into _back_transforms_mapping for each element
    #            (for reconstructing the corresponding Python value)
    _types_mapping = {
        int: "q",
        float: "d",
        bool: "xxxxxxx?",
        str: "%ds",
        bytes: "%ds",
        None.__class__: "xxxxxx?x",
    }
    _alignment = 8
    _back_transforms_mapping = {
        0: lambda value: value,                   # int, float, bool
        1: lambda value: value.rstrip(b'\x00').decode(_encoding),  # str
        2: lambda value: value.rstrip(b'\x00'),   # bytes
        3: lambda _value: None,                   # None
    }

    @staticmethod
    def _extract_recreation_code(value):
        """Used in concert with _back_transforms_mapping to convert values
        into the appropriate Python objects when retrieving them from
        the list as well as when storing them."""
        if not isinstance(value, (str, bytes, None.__class__)):
            return 0
        elif isinstance(value, str):
            return 1
        elif isinstance(value, bytes):
            return 2
        else:
            return 3  # NoneType

    @classmethod
    def _storage_format(cls, item):
        """Return the struct format used to store *item* at creation time.

        str and bytes items get an "<n>s" format where n is the encoded
        byte length rounded up to the next multiple of the alignment
        (always leaving at least one spare byte). Sizing by encoded bytes
        rather than by characters matters because struct's 's' code
        silently truncates values longer than n.

        Raises KeyError for unsupported item types."""
        fmt = cls._types_mapping[type(item)]
        if isinstance(item, str):
            item = item.encode(_encoding)
        if isinstance(item, bytes):
            return fmt % (
                cls._alignment * (len(item) // cls._alignment + 1),
            )
        return fmt

    def _normalize_position(self, position):
        """Return *position* as an index in range(len(self)).

        Negative positions count from the end, as for built-in sequences.
        Raises IndexError when the position is out of range in either
        direction."""
        if position < 0:
            position += self._list_len
        if not 0 <= position < self._list_len:
            raise IndexError("Requested position out of range.")
        return position

    def __init__(self, sequence=None, *, name=None):
        if name is None or sequence is not None:
            # Creating a new list. The items are walked several times
            # below, so materialize them once; this also makes one-shot
            # iterables usable.
            items = tuple(sequence or ())
            _formats = [self._storage_format(item) for item in items]
            self._list_len = len(_formats)
            assert sum(len(fmt) <= 8 for fmt in _formats) == self._list_len
            offset = 0
            # The offsets of each list element into the shared memory's
            # data area (0 meaning the start of the data area, not the start
            # of the shared memory area).
            self._allocated_offsets = [0]
            for fmt in _formats:
                offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
                self._allocated_offsets.append(offset)
            _recreation_codes = [
                self._extract_recreation_code(item) for item in items
            ]
            requested_size = struct.calcsize(
                "q" + self._format_size_metainfo +
                "".join(_formats) +
                self._format_packing_metainfo +
                self._format_back_transform_codes
            )

            self.shm = SharedMemory(name, create=True, size=requested_size)

            # Write the four sections in layout order: length and offsets,
            # item data, per-item formats, per-item back-transform codes.
            struct.pack_into(
                "q" + self._format_size_metainfo,
                self.shm.buf,
                0,
                self._list_len,
                *self._allocated_offsets
            )
            struct.pack_into(
                "".join(_formats),
                self.shm.buf,
                self._offset_data_start,
                *(v.encode(_encoding) if isinstance(v, str) else v
                  for v in items)
            )
            struct.pack_into(
                self._format_packing_metainfo,
                self.shm.buf,
                self._offset_packing_formats,
                *(fmt.encode(_encoding) for fmt in _formats)
            )
            struct.pack_into(
                self._format_back_transform_codes,
                self.shm.buf,
                self._offset_back_transform_codes,
                *_recreation_codes
            )

        else:
            # Attaching to an existing list: recover the length and the
            # offsets table from the block itself.
            self.shm = SharedMemory(name)
            self._list_len = len(self)  # Obtains size from offset 0 in buffer.
            self._allocated_offsets = list(
                struct.unpack_from(
                    self._format_size_metainfo,
                    self.shm.buf,
                    1 * 8
                )
            )

    def _get_packing_format(self, position):
        "Gets the packing format for a single value stored in the list."
        position = self._normalize_position(position)

        v = struct.unpack_from(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8
        )[0]
        fmt = v.rstrip(b'\x00')
        fmt_as_str = fmt.decode(_encoding)

        return fmt_as_str

    def _get_back_transform(self, position):
        "Gets the back transformation function for a single value."
        position = self._normalize_position(position)

        transform_code = struct.unpack_from(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position
        )[0]
        transform_function = self._back_transforms_mapping[transform_code]

        return transform_function

    def _set_packing_format_and_transform(self, position, fmt_as_str, value):
        """Sets the packing format and back transformation code for a
        single value in the list at the specified position."""
        position = self._normalize_position(position)

        struct.pack_into(
            "8s",
            self.shm.buf,
            self._offset_packing_formats + position * 8,
            fmt_as_str.encode(_encoding)
        )

        transform_code = self._extract_recreation_code(value)
        struct.pack_into(
            "b",
            self.shm.buf,
            self._offset_back_transform_codes + position,
            transform_code
        )

    def __getitem__(self, position):
        # Validate before touching the buffer: the offsets table has N + 1
        # entries, so an unchecked index of -N-1 would otherwise pass the
        # table lookup and read from the metadata area.
        try:
            position = self._normalize_position(position)
        except IndexError:
            raise IndexError("index out of range") from None

        offset = self._offset_data_start + self._allocated_offsets[position]
        (v,) = struct.unpack_from(
            self._get_packing_format(position),
            self.shm.buf,
            offset
        )

        back_transform = self._get_back_transform(position)
        v = back_transform(v)

        return v

    def __setitem__(self, position, value):
        # Validate first so an out-of-range index can never write into a
        # neighbouring item or into the metadata area.
        try:
            position = self._normalize_position(position)
        except IndexError:
            raise IndexError("assignment index out of range") from None

        item_offset = self._allocated_offsets[position]
        offset = self._offset_data_start + item_offset
        current_format = self._get_packing_format(position)

        if not isinstance(value, (str, bytes)):
            new_format = self._types_mapping[type(value)]
            encoded_value = value
        else:
            # A str/bytes value must fit in the slot sized at creation.
            allocated_length = self._allocated_offsets[position + 1] - item_offset

            encoded_value = (value.encode(_encoding)
                             if isinstance(value, str) else value)
            if len(encoded_value) > allocated_length:
                raise ValueError("bytes/str item exceeds available storage")
            if current_format[-1] == "s":
                new_format = current_format
            else:
                new_format = self._types_mapping[str] % (
                    allocated_length,
                )

        self._set_packing_format_and_transform(
            position,
            new_format,
            value
        )
        struct.pack_into(new_format, self.shm.buf, offset, encoded_value)

    def __reduce__(self):
        # Reconstructs by attaching to the same block by name.
        return partial(self.__class__, name=self.shm.name), ()

    def __len__(self):
        return struct.unpack_from("q", self.shm.buf, 0)[0]

    def __repr__(self):
        return f'{self.__class__.__name__}({list(self)}, name={self.shm.name!r})'

    @property
    def format(self):
        "The struct packing format used by all currently stored items."
        return "".join(
            self._get_packing_format(i) for i in range(self._list_len)
        )

    @property
    def _format_size_metainfo(self):
        "The struct packing format used for the items' storage offsets."
        return "q" * (self._list_len + 1)

    @property
    def _format_packing_metainfo(self):
        "The struct packing format used for the items' packing formats."
        return "8s" * self._list_len

    @property
    def _format_back_transform_codes(self):
        "The struct packing format used for the items' back transforms."
        return "b" * self._list_len

    @property
    def _offset_data_start(self):
        # - 8 bytes for the list length
        # - (N + 1) * 8 bytes for the element offsets
        return (self._list_len + 2) * 8

    @property
    def _offset_packing_formats(self):
        return self._offset_data_start + self._allocated_offsets[-1]

    @property
    def _offset_back_transform_codes(self):
        return self._offset_packing_formats + self._list_len * 8

    def count(self, value):
        "L.count(value) -> integer -- return number of occurrences of value."

        return sum(value == entry for entry in self)

    def index(self, value):
        """L.index(value) -> integer -- return first index of value.
        Raises ValueError if the value is not present."""

        for position, entry in enumerate(self):
            if value == entry:
                return position
        else:
            raise ValueError(f"{value!r} not in this container")

    __class_getitem__ = classmethod(types.GenericAlias)