Skip to content

Commit

Permalink
Fix load_many to account for columns that use 'db_field'
Browse files Browse the repository at this point in the history
  • Loading branch information
pacejackson committed Mar 27, 2020
1 parent 9beb70e commit 5f9952f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
20 changes: 17 additions & 3 deletions cqlmapper/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,10 @@ def load_many(cls, conn, keys, concurrency=25):
models with only a single PRIMARY KEY (single partition key and no
clustering keys) a simple list of values may be used. For cases with
multiple PRIMARY KEYS, a list of dicts mapping each primary key to
it's value must be given.
it's value must be given. The primary key name given in this dict must
match the table name, so if you are using `db_field` in that column, you
should use _that_ value, not the name of the Column field on your
cqlmapper model.
.. code-block:: python
Expand All @@ -648,6 +651,10 @@ class ComplexModel(Model):
ck = columns.Integer(primary_key=True) # clustering
value = columns.Text()
class DBFieldModel(Model):
_key = columns.Text(primary_key=True, db_field="key")
value = columns.Text()
valid_simple = SimpleModel.load_many(conn, ["fizz", "buzz"])
valid_simple = SimpleModel.load_many(conn, [{"key": "fizz"}, {"key: "buzz"}])
try:
Expand All @@ -667,6 +674,13 @@ class ComplexModel(Model):
except Exception:
pass
valid_db_field = DBFieldModel.load_many(conn, ["fizz", "buzz"])
valid_db_field = DBFieldModel.load_many(conn, [{"key": "fizz"}, {"key: "buzz"}])
try:
invalid_db_field = DBFieldModel.load_many(conn, [{"_key: "buzz"}])
except Exception:
pass
:type: List[Dict[str, Any]] or List[Any]
:param concurrency: Maximum number of queries to run concurrently.
:type: int
Expand All @@ -678,7 +692,7 @@ class ComplexModel(Model):
raise ValueError("'concurrency' in 'load_many' must be >= 1.")

# cls._primary_keys is an OrderedDict so no need to sort the keys
pks = list(cls._primary_keys.keys())
pks = [col.db_field_name for col in cls._primary_keys.values()]

# Support the "simple" format for Models that allow it
if len(pks) == 1 and not isinstance(keys[0], dict):
Expand All @@ -687,7 +701,7 @@ class ComplexModel(Model):
parameters = [tuple(key_values[key] for key in pks) for key_values in keys]
args_str = " AND ".join("{key} = ?".format(key=key) for key in pks)
# cls._columns is an OrderedDict so no need to sort the keys
cols = ",".join(col for col in cls._columns.keys())
cols = ",".join(col.db_field_name for col in cls._columns.values())
statement = conn.session.prepare(
"SELECT {columns} FROM {cf_name} WHERE {args}".format(
columns=cols, cf_name=cls.column_family_name(), args=args_str
Expand Down
14 changes: 14 additions & 0 deletions tests/integration/model/test_model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,11 @@ class SimpleModel(Model):
value = columns.Text()


class DBFieldNameModel(Model):
_pk = columns.Text(primary_key=True, db_field="pk")
value = columns.Text()


class ComplexModel(Model):
pk = columns.Text(primary_key=True)
ck = columns.Integer(primary_key=True)
Expand All @@ -977,19 +982,22 @@ def setUpClass(cls):
conn = cls.connection()
sync_table(conn, SimpleModel)
sync_table(conn, ComplexModel)
sync_table(conn, DBFieldNameModel)
SimpleModel.create(conn, key="alpha", value="omega")
SimpleModel.create(conn, key="foo", value="bar")
SimpleModel.create(conn, key="zip", value="zap")
ComplexModel.create(conn, pk="fizz", ck=0, value="buzz")
ComplexModel.create(conn, pk="fizz", ck=1, value="hunter2")
ComplexModel.create(conn, pk="key", ck=0, value="value")
DBFieldNameModel.create(conn, _pk="hunter", value="42")

@classmethod
def tearDownClass(cls):
super(TestLoadMany, cls).tearDownClass()
conn = cls.connection()
drop_table(conn, SimpleModel)
drop_table(conn, ComplexModel)
drop_table(conn, DBFieldNameModel)

def test_empty_keys(self):
conn = Mock()
Expand Down Expand Up @@ -1048,3 +1056,9 @@ def test_complex_model_missing_key(self):
def test_complex_model_simple_input(self):
with self.assertRaises(TypeError):
ComplexModel.load_many(self.conn, ["fizz"])

def test_db_field_name_models(self):
models = DBFieldNameModel.load_many(self.conn, ["hunter"])
self.assertEqual(models, [DBFieldNameModel(_pk="hunter", value="42")])
models = DBFieldNameModel.load_many(self.conn, [{"pk": "hunter"}])
self.assertEqual(models, [DBFieldNameModel(_pk="hunter", value="42")])

0 comments on commit 5f9952f

Please sign in to comment.