Skip to content

Commit

Permalink
Fix initdb error on Windows (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
demonolock authored Dec 19, 2023
1 parent 846c05f commit 1a2f6da
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 122 deletions.
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@
readme = f.read()

setup(
version='1.9.2',
version='1.9.3',
name='testgres',
packages=['testgres', 'testgres.operations'],
packages=['testgres', 'testgres.operations', 'testgres.helpers'],
description='Testing utility for PostgreSQL and its extensions',
url='https://github.com/postgrespro/testgres',
long_description=readme,
long_description_content_type='text/markdown',
license='PostgreSQL',
author='Ildar Musin',
author_email='[email protected]',
author='Postgres Professional',
author_email='[email protected]',
keywords=['test', 'testing', 'postgresql'],
install_requires=install_requires,
classifiers=[],
Expand Down
4 changes: 3 additions & 1 deletion testgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
from .operations.local_ops import LocalOperations
from .operations.remote_ops import RemoteOperations

from .helpers.port_manager import PortManager

__all__ = [
"get_new_node",
"get_remote_node",
Expand All @@ -62,6 +64,6 @@
"XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat",
"PostgresNode", "NodeApp",
"reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version",
"First", "Any",
"First", "Any", "PortManager",
"OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams"
]
Empty file added testgres/helpers/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions testgres/helpers/port_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import socket
import random
from typing import Set, Iterable, Optional


class PortForException(Exception):
pass


class PortManager:
def __init__(self, ports_range=(1024, 65535)):
self.ports_range = ports_range

@staticmethod
def is_port_free(port: int) -> bool:
"""Check if a port is free to use."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("", port))
return True
except OSError:
return False

def find_free_port(self, ports: Optional[Set[int]] = None, exclude_ports: Optional[Iterable[int]] = None) -> int:
"""Return a random unused port number."""
if ports is None:
ports = set(range(1024, 65535))

if exclude_ports is None:
exclude_ports = set()

ports.difference_update(set(exclude_ports))

sampled_ports = random.sample(tuple(ports), min(len(ports), 100))

for port in sampled_ports:
if self.is_port_free(port):
return port

raise PortForException("Can't select a port")
8 changes: 4 additions & 4 deletions testgres/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,8 @@ def status(self):
"-D", self.data_dir,
"status"
] # yapf: disable
status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True)
if 'does not exist' in err:
status_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
if error and 'does not exist' in error:
return NodeStatus.Uninitialized
elif 'no server running' in out:
return NodeStatus.Stopped
Expand Down Expand Up @@ -717,7 +717,7 @@ def start(self, params=[], wait=True):

try:
exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
if 'does not exist' in error:
if error and 'does not exist' in error:
raise Exception
except Exception as e:
msg = 'Cannot start node'
Expand Down Expand Up @@ -791,7 +791,7 @@ def restart(self, params=[]):

try:
error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
if 'could not start server' in error:
if error and 'could not start server' in error:
raise ExecUtilException
except ExecUtilException as e:
msg = 'Cannot restart node'
Expand Down
133 changes: 74 additions & 59 deletions testgres/operations/local_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import psutil

from ..exceptions import ExecUtilException
from .os_ops import ConnectionParams, OsOperations
from .os_ops import pglib
from .os_ops import ConnectionParams, OsOperations, pglib, get_default_encoding

try:
from shutil import which as find_executable
Expand All @@ -22,6 +21,14 @@
error_markers = [b'error', b'Permission denied', b'fatal']


def has_errors(output):
if output:
if isinstance(output, str):
output = output.encode(get_default_encoding())
return any(marker in output for marker in error_markers)
return False


class LocalOperations(OsOperations):
def __init__(self, conn_params=None):
if conn_params is None:
Expand All @@ -33,72 +40,80 @@ def __init__(self, conn_params=None):
self.remote = False
self.username = conn_params.username or self.get_user()

# Command execution
def exec_command(self, cmd, wait_exit=False, verbose=False,
expect_error=False, encoding=None, shell=False, text=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
get_process=None, timeout=None):
"""
Execute a command in a subprocess.
Args:
- cmd: The command to execute.
- wait_exit: Whether to wait for the subprocess to exit before returning.
- verbose: Whether to return verbose output.
- expect_error: Whether to raise an error if the subprocess exits with an error status.
- encoding: The encoding to use for decoding the subprocess output.
- shell: Whether to use shell when executing the subprocess.
- text: Whether to return str instead of bytes for the subprocess output.
- input: The input to pass to the subprocess.
- stdout: The stdout to use for the subprocess.
- stderr: The stderr to use for the subprocess.
- proc: The process to use for subprocess creation.
:return: The output of the subprocess.
"""
if os.name == 'nt':
with tempfile.NamedTemporaryFile() as buf:
process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT)
process.communicate()
buf.seek(0)
result = buf.read().decode(encoding)
return result
else:
@staticmethod
def _raise_exec_exception(message, command, exit_code, output):
"""Raise an ExecUtilException."""
raise ExecUtilException(message=message.format(output),
command=command,
exit_code=exit_code,
out=output)

@staticmethod
def _process_output(encoding, temp_file_path):
"""Process the output of a command from a temporary file."""
with open(temp_file_path, 'rb') as temp_file:
output = temp_file.read()
if encoding:
output = output.decode(encoding)
return output, None # In Windows stderr writing in stdout

def _run_command(self, cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding):
"""Execute a command and return the process and its output."""
if os.name == 'nt' and stdout is None: # Windows
with tempfile.NamedTemporaryFile(mode='w+b', delete=False) as temp_file:
stdout = temp_file
stderr = subprocess.STDOUT
process = subprocess.Popen(
cmd,
shell=shell,
stdin=stdin or subprocess.PIPE if input is not None else None,
stdout=stdout,
stderr=stderr,
)
if get_process:
return process, None, None
temp_file_path = temp_file.name

# Wait process finished
process.wait()

output, error = self._process_output(encoding, temp_file_path)
return process, output, error
else: # Other OS
process = subprocess.Popen(
cmd,
shell=shell,
stdout=stdout,
stderr=stderr,
stdin=stdin or subprocess.PIPE if input is not None else None,
stdout=stdout or subprocess.PIPE,
stderr=stderr or subprocess.PIPE,
)
if get_process:
return process

return process, None, None
try:
result, error = process.communicate(input, timeout=timeout)
output, error = process.communicate(input=input.encode(encoding) if input else None, timeout=timeout)
if encoding:
output = output.decode(encoding)
error = error.decode(encoding)
return process, output, error
except subprocess.TimeoutExpired:
process.kill()
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
exit_status = process.returncode

error_found = exit_status != 0 or any(marker in error for marker in error_markers)

if encoding:
result = result.decode(encoding)
error = error.decode(encoding)

if expect_error:
raise Exception(result, error)

if exit_status != 0 or error_found:
if exit_status == 0:
exit_status = 1
raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error),
command=cmd,
exit_code=exit_status,
out=result)
if verbose:
return exit_status, result, error
else:
return result
def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=False,
text=False, input=None, stdin=None, stdout=None, stderr=None, get_process=False, timeout=None):
"""
Execute a command in a subprocess and handle the output based on the provided parameters.
"""
process, output, error = self._run_command(cmd, shell, input, stdin, stdout, stderr, get_process, timeout, encoding)
if get_process:
return process
if process.returncode != 0 or (has_errors(error) and not expect_error):
self._raise_exec_exception('Utility exited with non-zero code. Error `{}`', cmd, process.returncode, error)

if verbose:
return process.returncode, output, error
else:
return output

# Environment setup
def environ(self, var_name):
Expand Down Expand Up @@ -210,7 +225,7 @@ def read(self, filename, encoding=None, binary=False):
if binary:
return content
if isinstance(content, bytes):
return content.decode(encoding or 'utf-8')
return content.decode(encoding or get_default_encoding())
return content

def readlines(self, filename, num_lines=0, binary=False, encoding=None):
Expand Down
8 changes: 7 additions & 1 deletion testgres/operations/os_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import locale

try:
import psycopg2 as pglib # noqa: F401
except ImportError:
Expand All @@ -14,6 +16,10 @@ def __init__(self, host='127.0.0.1', ssh_key=None, username=None):
self.username = username


def get_default_encoding():
return locale.getdefaultlocale()[1] or 'UTF-8'


class OsOperations:
def __init__(self, username=None):
self.ssh_key = None
Expand Down Expand Up @@ -75,7 +81,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
def touch(self, filename):
raise NotImplementedError()

def read(self, filename):
def read(self, filename, encoding, binary):
raise NotImplementedError()

def readlines(self, filename):
Expand Down
Loading

0 comments on commit 1a2f6da

Please sign in to comment.