[utils] Use locked_file for sanitize_open (#1066)

Authored by: jakeogh
This commit is contained in:
Justin Keogh 2022-02-05 10:45:51 +00:00 committed by GitHub
parent f1657a98cb
commit a3125791c7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -665,7 +665,7 @@ def sanitize_open(filename, open_mode):
import msvcrt import msvcrt
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
return (sys.stdout.buffer if hasattr(sys.stdout, 'buffer') else sys.stdout, filename) return (sys.stdout.buffer if hasattr(sys.stdout, 'buffer') else sys.stdout, filename)
stream = open(encodeFilename(filename), open_mode) stream = locked_file(filename, open_mode, block=False).open()
return (stream, filename) return (stream, filename)
except (IOError, OSError) as err: except (IOError, OSError) as err:
if err.errno in (errno.EACCES,): if err.errno in (errno.EACCES,):
@ -677,7 +677,7 @@ def sanitize_open(filename, open_mode):
raise raise
else: else:
# An exception here should be caught in the caller # An exception here should be caught in the caller
stream = open(encodeFilename(alt_filename), open_mode) stream = locked_file(filename, open_mode, block=False).open()
return (stream, alt_filename) return (stream, alt_filename)
@ -2115,7 +2115,7 @@ class OVERLAPPED(ctypes.Structure):
whole_low = 0xffffffff whole_low = 0xffffffff
whole_high = 0x7fffffff whole_high = 0x7fffffff
def _lock_file(f, exclusive): def _lock_file(f, exclusive, block): # todo: block unused on win32
overlapped = OVERLAPPED() overlapped = OVERLAPPED()
overlapped.Offset = 0 overlapped.Offset = 0
overlapped.OffsetHigh = 0 overlapped.OffsetHigh = 0
@ -2138,15 +2138,19 @@ def _unlock_file(f):
try: try:
import fcntl import fcntl
def _lock_file(f, exclusive): def _lock_file(f, exclusive, block):
fcntl.flock(f, fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH) fcntl.flock(f,
fcntl.LOCK_SH if not exclusive
else fcntl.LOCK_EX if block
else fcntl.LOCK_EX | fcntl.LOCK_NB)
def _unlock_file(f): def _unlock_file(f):
fcntl.flock(f, fcntl.LOCK_UN) fcntl.flock(f, fcntl.LOCK_UN)
except ImportError: except ImportError:
UNSUPPORTED_MSG = 'file locking is not supported on this platform' UNSUPPORTED_MSG = 'file locking is not supported on this platform'
def _lock_file(f, exclusive): def _lock_file(f, exclusive, block):
raise IOError(UNSUPPORTED_MSG) raise IOError(UNSUPPORTED_MSG)
def _unlock_file(f): def _unlock_file(f):
@ -2154,15 +2158,16 @@ def _unlock_file(f):
class locked_file(object): class locked_file(object):
def __init__(self, filename, mode, encoding=None): def __init__(self, filename, mode, block=True, encoding=None):
assert mode in ['r', 'a', 'w'] assert mode in ['r', 'rb', 'a', 'ab', 'w', 'wb']
self.f = io.open(filename, mode, encoding=encoding) self.f = io.open(filename, mode, encoding=encoding)
self.mode = mode self.mode = mode
self.block = block
def __enter__(self): def __enter__(self):
exclusive = self.mode != 'r' exclusive = 'r' not in self.mode
try: try:
_lock_file(self.f, exclusive) _lock_file(self.f, exclusive, self.block)
except IOError: except IOError:
self.f.close() self.f.close()
raise raise
@ -2183,6 +2188,15 @@ def write(self, *args):
def read(self, *args): def read(self, *args):
return self.f.read(*args) return self.f.read(*args)
def flush(self):
self.f.flush()
def open(self):
return self.__enter__()
def close(self, *args):
self.__exit__(self, *args, value=False, traceback=False)
def get_filesystem_encoding(): def get_filesystem_encoding():
encoding = sys.getfilesystemencoding() encoding = sys.getfilesystemencoding()