Skip to content

Commit

Permalink
Remove sshtunnel
Browse files Browse the repository at this point in the history
  • Loading branch information
ВашÐViktoria Shepard committed Oct 18, 2023
1 parent 46eb92a commit 263ff9c
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 59 deletions.
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
"port-for>=0.4",
"six>=1.9.0",
"psutil",
"packaging",
"sshtunnel"
"packaging"
]

# Add compatibility enum class
Expand Down
11 changes: 10 additions & 1 deletion testgres/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@ def __str__(self):
if self.out:
msg.append(u'----\n{}'.format(self.out))

return six.text_type('\n').join(msg)
return self.convert_and_join(msg)

@staticmethod
def convert_and_join(msg_list):
# Convert each byte element in the list to str
str_list = [six.text_type(item, 'utf-8') if isinstance(item, bytes) else six.text_type(item) for item in
msg_list]

# Join the list into a single string with the specified delimiter
return six.text_type('\n').join(str_list)


@six.python_2_unicode_compatible
Expand Down
2 changes: 1 addition & 1 deletion testgres/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ def pgbench(self,
# should be the last one
_params.append(dbname)

proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True)
proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, get_process=True)

return proc

Expand Down
13 changes: 9 additions & 4 deletions testgres/operations/local_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from distutils.spawn import find_executable
from distutils import rmtree


CMD_TIMEOUT_SEC = 60
error_markers = [b'error', b'Permission denied', b'fatal']

Expand All @@ -37,7 +36,8 @@ def __init__(self, conn_params=None):
# Command execution
def exec_command(self, cmd, wait_exit=False, verbose=False,
expect_error=False, encoding=None, shell=False, text=False,
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None):
input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
get_process=None, timeout=None):
"""
Execute a command in a subprocess.
Expand Down Expand Up @@ -69,9 +69,14 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
stdout=stdout,
stderr=stderr,
)
if proc:
if get_process:
return process
result, error = process.communicate(input)

try:
result, error = process.communicate(input, timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))
exit_status = process.returncode

error_found = exit_status != 0 or any(marker in error for marker in error_markers)
Expand Down
91 changes: 43 additions & 48 deletions testgres/operations/remote_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import os
import subprocess
import tempfile
import time

import sshtunnel
# we support both pg8000 and psycopg2
try:
import psycopg2 as pglib
except ImportError:
try:
import pg8000 as pglib
except ImportError:
raise ImportError("You must have psycopg2 or pg8000 modules installed")

from ..exceptions import ExecUtilException

from .os_ops import OsOperations, ConnectionParams
from .os_ops import pglib

sshtunnel.SSH_TIMEOUT = 5.0
sshtunnel.TUNNEL_TIMEOUT = 5.0

ConsoleEncoding = locale.getdefaultlocale()[1]
if not ConsoleEncoding:
Expand Down Expand Up @@ -50,21 +52,28 @@ def __init__(self, conn_params: ConnectionParams):
self.remote = True
self.username = conn_params.username or self.get_user()
self.add_known_host(self.host)
self.tunnel_process = None

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.close_tunnel()
self.close_ssh_tunnel()

def close_tunnel(self):
if getattr(self, 'tunnel', None):
self.tunnel.stop(force=True)
start_time = time.time()
while self.tunnel.is_active:
if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT:
break
time.sleep(0.5)
def establish_ssh_tunnel(self, local_port, remote_port):
"""
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
"""
ssh_cmd = ['-N', '-L', f"{local_port}:localhost:{remote_port}"]
self.tunnel_process = self.exec_command(ssh_cmd, get_process=True, timeout=300)

def close_ssh_tunnel(self):
if hasattr(self, 'tunnel_process'):
self.tunnel_process.terminate()
self.tunnel_process.wait()
del self.tunnel_process
else:
print("No active tunnel to close.")

def add_known_host(self, host):
cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host, os.getlogin())
Expand All @@ -78,21 +87,29 @@ def add_known_host(self, host):
raise ExecUtilException(message="Failed to add %s to known_hosts. Error: %s" % (host, str(e)), command=cmd,
exit_code=e.returncode, out=e.stderr)

def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False,
def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
encoding=None, shell=True, text=False, input=None, stdin=None, stdout=None,
stderr=None, proc=None):
stderr=None, get_process=None, timeout=None):
"""
Execute a command in the SSH session.
Args:
- cmd (str): The command to be executed.
"""
ssh_cmd = []
if isinstance(cmd, str):
ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key, cmd]
elif isinstance(cmd, list):
ssh_cmd = ['ssh', f"{self.username}@{self.host}", '-i', self.ssh_key] + cmd
process = subprocess.Popen(ssh_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if get_process:
return process

try:
result, error = process.communicate(input, timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
raise ExecUtilException("Command timed out after {} seconds.".format(timeout))

result, error = process.communicate(input)
exit_status = process.returncode

if encoding:
Expand Down Expand Up @@ -372,41 +389,19 @@ def get_process_children(self, pid):
raise ExecUtilException(f"Error in getting process children. Error: {result.stderr}")

# Database control
def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None):
def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
"""
Connects to a PostgreSQL database on the remote system.
Args:
- dbname (str): The name of the database to connect to.
- user (str): The username for the database connection.
- password (str, optional): The password for the database connection. Defaults to None.
- host (str, optional): The IP address of the remote system. Defaults to "localhost".
- port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
This function establishes a connection to a PostgreSQL database on the remote system using the specified
parameters. It returns a connection object that can be used to interact with the database.
Established SSH tunnel and Connects to a PostgreSQL
"""
self.close_tunnel()
self.tunnel = sshtunnel.open_tunnel(
(self.host, 22), # Remote server IP and SSH port
ssh_username=self.username,
ssh_pkey=self.ssh_key,
remote_bind_address=(self.host, port), # PostgreSQL server IP and PostgreSQL port
local_bind_address=('localhost', 0)
# Local machine IP and available port (0 means it will pick any available port)
)
self.tunnel.start()

self.establish_ssh_tunnel(local_port=port, remote_port=5432)
try:
# Use localhost and self.tunnel.local_bind_port to connect
conn = pglib.connect(
host='localhost', # Connect to localhost
port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel
host=host,
port=port,
database=dbname,
user=user or self.username,
password=password
user=user,
password=password,
)

return conn
except Exception as e:
self.tunnel.stop()
raise ExecUtilException("Could not create db tunnel. {}".format(e))
raise Exception(f"Could not connect to the database. Error: {e}")
7 changes: 4 additions & 3 deletions tests/test_simple_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,10 @@ def test_pgbench(self):
options=['-q']).pgbench_run(time=2)

# run TPC-B benchmark
out = node.pgbench(stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
options=['-T3'])
proc = node.pgbench(stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
options=['-T3'])
out = proc.communicate()[0]
self.assertTrue(b'tps = ' in out)

def test_pg_config(self):
Expand Down

0 comments on commit 263ff9c

Please sign in to comment.