Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement server-side upsert for Postgres #335

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ jobs:
pip install -e ".[dev]"
- name: Run SQLite tests
env:
DATABASE_URI: 'sqlite:///:memory:'
run: |
make test
- name: Run PostgreSQL tests
env:
DATABASE_URI: 'postgresql+psycopg2://postgres:postgres@postgres:${{ job.services.postgres.ports[5432] }}/dataset'
DATABASE_URL: 'sqlite:///:memory:'
run: |
make test
# - name: Run PostgreSQL tests
# env:
# DATABASE_URL: 'postgresql+psycopg2://postgres:[email protected]:${{ job.services.postgres.ports[5432] }}/dataset'
# run: |
# make test
- name: Run MariaDB tests
env:
DATABASE_URI: 'mysql+pymysql://mariadb:mariadb@mariadb:${{ job.services.mariadb.ports[3306] }}/dataset?charset=utf8'
DATABASE_URL: 'mysql+pymysql://mariadb:mariadb@127.0.0.1:${{ job.services.mariadb.ports[3306] }}/dataset?charset=utf8'
run: |
make test
- name: Run flake8 to lint
Expand Down
4 changes: 2 additions & 2 deletions dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None,
the `ensure_schema` argument. It can also be overridden in a lot of the
data manipulation methods using the `ensure` flag.

.. _SQLAlchemy Engine URL: http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine
.. _DB connection timeout: http://docs.sqlalchemy.org/en/latest/core/pooling.html#setting-pool-recycle
.. _SQLAlchemy Engine URL: http://docs.sqlalchemy.org/en/latest/core/engines.html#sqlalchemy.create_engine # noqa
.. _DB connection timeout: http://docs.sqlalchemy.org/en/latest/core/pooling.html#setting-pool-recycle # noqa
"""
if url is None:
url = os.environ.get('DATABASE_URL', 'sqlite://')
Expand Down
14 changes: 3 additions & 11 deletions dataset/chunked.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import itertools


class InvalidCallback(ValueError):
Expand All @@ -10,12 +9,12 @@ def __init__(self, table, chunksize, callback):
self.queue = []
self.table = table
self.chunksize = chunksize
if callback and not callable(callback):
if callback is not None and not callable(callback):
raise InvalidCallback
self.callback = callback

def flush(self):
self.queue.clear()
self.queue = []

def _queue_add(self, item):
self.queue.append(item)
Expand All @@ -41,17 +40,12 @@ class ChunkedInsert(_Chunker):
"""

def __init__(self, table, chunksize=1000, callback=None):
self.fields = set()
super().__init__(table, chunksize, callback)

def insert(self, item):
self.fields.update(item.keys())
super()._queue_add(item)

def flush(self):
for item in self.queue:
for field in self.fields:
item[field] = item.get(field)
if self.callback is not None:
self.callback(self.queue)
self.table.insert_many(self.queue)
Expand Down Expand Up @@ -79,7 +73,5 @@ def update(self, item):
def flush(self):
if self.callback is not None:
self.callback(self.queue)
self.queue.sort(key=dict.keys)
for fields, items in itertools.groupby(self.queue, key=dict.keys):
self.table.update_many(list(items), self.keys)
self.table.update_many(self.queue, self.keys)
super().flush()
3 changes: 2 additions & 1 deletion dataset/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(self, url, schema=None, reflect_metadata=True,

self.schema = schema
self.engine = create_engine(url, **engine_kwargs)
self.types = Types(self.engine.dialect.name)
self.is_postgres = self.engine.dialect.name == 'postgresql'
self.types = Types(is_postgres=self.is_postgres)
self.url = url
self.row_type = row_type
self.ensure_schema = ensure_schema
Expand Down
164 changes: 74 additions & 90 deletions dataset/table.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import logging
import warnings
import threading
from banal import ensure_list

from sqlalchemy import func, select, false
from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import bindparam, ClauseElement
from sqlalchemy.dialects.postgresql import insert as upsert
from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.exc import NoSuchTableError

from dataset.types import Types
from dataset.util import index_name, ensure_tuple
from dataset.util import DatasetException, ResultIter, QUERY_STEP
from dataset.util import normalize_table_name, pad_chunk_columns
from dataset.util import normalize_table_name, index_name
from dataset.util import normalize_column_name, normalize_column_key


Expand Down Expand Up @@ -153,28 +154,8 @@ def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
rows = [dict(name='Dolly')] * 10000
table.insert_many(rows)
"""
# Sync table before inputting rows.
sync_row = {}
for row in rows:
# Only get non-existing columns.
sync_keys = list(sync_row.keys())
for key in [k for k in row.keys() if k not in sync_keys]:
# Get a sample of the new column(s) from the row.
sync_row[key] = row[key]
self._sync_columns(sync_row, ensure, types=types)

# Get columns name list to be used for padding later.
columns = sync_row.keys()

chunk = []
for index, row in enumerate(rows):
chunk.append(row)

# Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns)
self.table.insert().execute(chunk)
chunk = []
for chunk, _ in self._row_chunks(rows, chunk_size, ensure, types):
self.db.executable.execute(self.table.insert(), chunk)

def update(self, row, keys, ensure=None, types=None, return_count=False):
"""Update a row in the table.
Expand All @@ -193,17 +174,13 @@ def update(self, row, keys, ensure=None, types=None, return_count=False):
be created based on the settings of ``ensure`` and ``types``, matching
the behavior of :py:meth:`insert() <dataset.Table.insert>`.
"""
row = self._sync_columns(row, ensure, types=types)
args, row = self._keys_to_args(row, keys)
clause = self._args_to_clause(args)
if not len(row):
return self.count(clause)
stmt = self.table.update(whereclause=clause, values=row)
rp = self.db.executable.execute(stmt)
if rp.supports_sane_rowcount():
return rp.rowcount
if return_count:
rowcount = self.update_many([row], keys, ensure=ensure, types=types)
if rowcount is None and return_count:
row = self._sync_columns(row, ensure, types=types)
args, row = self._keys_to_args(row, keys)
clause = self._args_to_clause(args)
return self.count(clause)
return rowcount

def update_many(self, rows, keys, chunk_size=1000, ensure=None,
types=None):
Expand All @@ -216,32 +193,25 @@ def update_many(self, rows, keys, chunk_size=1000, ensure=None,
See :py:meth:`update() <dataset.Table.update>` for details on
the other parameters.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]

chunk = []
columns = []
for index, row in enumerate(rows):
chunk.append(row)
for col in row.keys():
if col not in columns:
columns.append(col)

# bindparam requires names to not conflict (cannot be "id" for id)
for key in keys:
row['_%s' % key] = row[key]

# Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam('_%s' % k) for k in keys]
stmt = self.table.update(
whereclause=and_(*cl),
values={
col: bindparam(col, required=False) for col in columns
}
)
self.db.executable.execute(stmt, chunk)
chunk = []
keys = [self._get_column_name(k) for k in ensure_list(keys)]
bindings = [(k, 'u_%s' % k) for k in keys]
rowcount = 0
for chunk, cols in self._row_chunks(rows, chunk_size, ensure, types):
for row in chunk:
# bindparam requires names to not conflict
# (cannot be "id" for id)
for key, alias in bindings:
row[alias] = row.get(key, None)

cl = [self.table.c[k] == bindparam(a) for k, a in bindings]
values = {c: bindparam(c, required=False) for c in cols}
stmt = self.table.update(whereclause=and_(*cl), values=values)
rp = self.db.executable.execute(stmt, chunk)
if rp.supports_sane_rowcount():
rowcount += rp.rowcount
else:
rowcount = None
return rowcount

def upsert(self, row, keys, ensure=None, types=None):
"""An UPSERT is a smart combination of insert and update.
Expand All @@ -253,13 +223,7 @@ def upsert(self, row, keys, ensure=None, types=None):
data = dict(id=10, title='I am a banana!')
table.upsert(data, ['id'])
"""
row = self._sync_columns(row, ensure, types=types)
if self._check_ensure(ensure):
self.create_index(keys)
row_count = self.update(row, keys, ensure=False, return_count=True)
if row_count == 0:
return self.insert(row, ensure=False)
return True
return self.upsert_many([row], keys, ensure=ensure, types=types)

def upsert_many(self, rows, keys, chunk_size=1000, ensure=None,
types=None):
Expand All @@ -270,24 +234,21 @@ def upsert_many(self, rows, keys, chunk_size=1000, ensure=None,
See :py:meth:`upsert() <dataset.Table.upsert>` and
:py:meth:`insert_many() <dataset.Table.insert_many>`.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]

to_insert = []
to_update = []
for row in rows:
if self.find_one(**{key: row.get(key) for key in keys}):
# Row exists - update it.
to_update.append(row)
else:
# Row doesn't exist - insert it.
to_insert.append(row)

# Insert non-existing rows.
self.insert_many(to_insert, chunk_size, ensure, types)

# Update existing rows.
self.update_many(to_update, keys, chunk_size, ensure, types)
if self.db.is_postgres:
return self._upsert_postgres(rows, keys, chunk_size, ensure, types)
for row in ensure_list(rows):
cnt = self.update(row, keys, ensure=ensure, return_count=True)
if cnt == 0:
self.insert(row, ensure=ensure)

def _upsert_postgres(self, rows, keys, chunk_size, ensure, types):
"""Postgres ON CONFLICT UPDATE for INSERT statements."""
keys = [self._get_column_name(k) for k in ensure_list(keys)]
for chunk, cols in self._row_chunks(rows, chunk_size, ensure, types):
stmt = upsert(self.table).values(chunk)
set_ = {c: stmt.excluded[c] for c in cols if c not in keys}
stmt = stmt.on_conflict_do_update(index_elements=keys, set_=set_)
self.db.executable.execute(stmt)

def delete(self, *clauses, **filters):
"""Delete rows from the table.
Expand All @@ -307,6 +268,30 @@ def delete(self, *clauses, **filters):
rp = self.db.executable.execute(stmt)
return rp.rowcount > 0

def _row_chunks(self, rows, chunk_size, ensure, types):
"""Normalise a set of rows for a bulk operation, with table
adaptation."""
# Sync table before inputting rows.
sync_row = {}
for row in rows:
for key in row.keys():
if sync_row.get(key) is None:
sync_row[key] = row[key]
self._sync_columns(sync_row, ensure, types=types)
columns = tuple(sync_row.keys())

chunk = []
for row in rows:
row = dict(row)
for column in sync_row.keys():
row.setdefault(column, None)
chunk.append(row)
if len(chunk) >= chunk_size:
yield chunk, columns
chunk = []
if len(chunk):
yield chunk, columns

def _reflect_table(self):
"""Load the tables definition from the database."""
with self.db.lock:
Expand Down Expand Up @@ -442,7 +427,7 @@ def _args_to_clause(self, args, clauses=()):

def _args_to_order_by(self, order_by):
orderings = []
for ordering in ensure_tuple(order_by):
for ordering in ensure_list(order_by):
if ordering is None:
continue
column = ordering.lstrip('-')
Expand All @@ -456,8 +441,7 @@ def _args_to_order_by(self, order_by):
return orderings

def _keys_to_args(self, row, keys):
keys = ensure_tuple(keys)
keys = [self._get_column_name(k) for k in keys]
keys = [self._get_column_name(k) for k in ensure_list(keys)]
row = row.copy()
args = {k: row.pop(k, None) for k in keys}
return args, row
Expand Down Expand Up @@ -557,7 +541,7 @@ def create_index(self, columns, name=None, **kw):

table.create_index(['name', 'country'])
"""
columns = [self._get_column_name(c) for c in ensure_tuple(columns)]
columns = [self._get_column_name(c) for c in ensure_list(columns)]
with self.db.lock:
if not self.exists:
raise DatasetException("Table has not been created yet.")
Expand Down
10 changes: 2 additions & 8 deletions dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,8 @@ class Types(object):
date = Date
datetime = DateTime

def __init__(self, dialect=None):
self._dialect = dialect

@property
def json(self):
if self._dialect is not None and self._dialect == 'postgresql':
return JSONB
return JSON
def __init__(self, is_postgres=None):
self.json = JSONB if is_postgres else JSON

def guess(self, sample):
"""Given a single sample, guess the column type for the field.
Expand Down
19 changes: 0 additions & 19 deletions dataset/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from hashlib import sha1
from urllib.parse import urlparse
from collections import OrderedDict
from collections.abc import Iterable
from sqlalchemy.exc import ResourceClosedError

QUERY_STEP = 1000
Expand Down Expand Up @@ -108,21 +107,3 @@ def index_name(table, columns):
sig = '||'.join(columns)
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
return 'ix_%s_%s' % (table, key)


def ensure_tuple(obj):
"""Try and make the given argument into a tuple."""
if obj is None:
return tuple()
if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
return tuple(obj)
return obj,


def pad_chunk_columns(chunk, columns):
"""Given a set of items to be inserted, make sure they all have the
same columns by padding columns with None if they are missing."""
for record in chunk:
for column in columns:
record.setdefault(column, None)
return chunk
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
zip_safe=False,
install_requires=[
'sqlalchemy >= 1.3.2',
'alembic >= 0.6.2'
'alembic >= 0.6.2',
'banal >= 1.0.1',
],
extras_require={
'dev': [
Expand Down
Loading