Skip to content

Commit

Permalink
First implementation of the new consumer API
Browse files Browse the repository at this point in the history
- Added Consumer class to handle messages consumsion
- Added consumer queues in channels to wire the Consumer with the Channel
- Fixed all test
- Need test of stop of consumsion, documentation and update of examples
  • Loading branch information
Sergio Medina Toledo committed Oct 27, 2016
1 parent 2142244 commit 6ac5bd2
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 80 deletions.
35 changes: 24 additions & 11 deletions aioamqp/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import io
from itertools import count

from aioamqp.consumer import Consumer

from . import constants as amqp_constants
from . import frame as amqp_frame
from . import exceptions
Expand All @@ -24,7 +26,6 @@ def __init__(self, protocol, channel_id):
self.protocol = protocol
self.channel_id = channel_id
self.consumer_queues = {}
self.consumer_callbacks = {}
self.response_future = None
self.close_event = asyncio.Event(loop=self._loop)
self.cancelled_consumers = set()
Expand Down Expand Up @@ -67,6 +68,9 @@ def connection_closed(self, server_code=None, server_reason=None, exception=None
future.set_exception(exception)

self.protocol.release_channel_id(self.channel_id)

for queue in self.consumer_queues.values():
yield from queue.put(None)
self.close_event.set()

@asyncio.coroutine
Expand Down Expand Up @@ -163,6 +167,10 @@ def close(self, reply_code=0, reply_text="Normal Shutdown", timeout=None):
def close_ok(self, frame):
self._get_waiter('close').set_result(True)
logger.info("Channel closed")

for queue in self.consumer_queues.values():
yield from queue.put(None)

self.protocol.release_channel_id(self.channel_id)

@asyncio.coroutine
Expand Down Expand Up @@ -573,13 +581,12 @@ def basic_server_nack(self, frame, delivery_tag=None):
fut.set_exception(exceptions.PublishFailed(delivery_tag))

@asyncio.coroutine
def basic_consume(self, callback, queue_name='', consumer_tag='', no_local=False, no_ack=False,
def basic_consume(self, queue_name='', consumer_tag='', no_local=False, no_ack=False,
exclusive=False, no_wait=False, arguments=None, timeout=None):
"""Starts the consumption of message into a queue.
the callback will be called each time we're receiving a message.
Args:
callback: coroutine, the called callback
queue_name: str, the queue to receive message from
consumer_tag: str, optional consumer tag
no_local: bool, if set the server will not send messages
Expand Down Expand Up @@ -609,16 +616,15 @@ def basic_consume(self, callback, queue_name='', consumer_tag='', no_local=False
request.write_bits(no_local, no_ack, exclusive, no_wait)
request.write_table(arguments)

self.consumer_callbacks[consumer_tag] = callback
self.last_consumer_tag = consumer_tag
self.consumer_queues[consumer_tag] = asyncio.Queue()

consumer = Consumer(self.consumer_queues[consumer_tag], consumer_tag)

return_value = yield from self._write_frame_awaiting_response(
'basic_consume', frame, request, no_wait, timeout=timeout)
if no_wait:
return_value = {'consumer_tag': consumer_tag}
else:
if not no_wait:
self._ctag_events[consumer_tag].set()
return return_value
return consumer

@asyncio.coroutine
def basic_consume_ok(self, frame):
Expand Down Expand Up @@ -649,20 +655,22 @@ def basic_deliver(self, frame):
envelope = Envelope(consumer_tag, delivery_tag, exchange_name, routing_key, is_redeliver)
properties = content_header_frame.properties

callback = self.consumer_callbacks[consumer_tag]
consumer_queue = self.consumer_queues[consumer_tag]

event = self._ctag_events.get(consumer_tag)
if event:
yield from event.wait()
del self._ctag_events[consumer_tag]

yield from callback(self, body, envelope, properties)
yield from consumer_queue.put((self, body, envelope, properties))

@asyncio.coroutine
def server_basic_cancel(self, frame):
"""From the server, means the server won't send anymore messages to this consumer."""
consumer_tag = frame.arguments['consumer_tag']
self.cancelled_consumers.add(consumer_tag)
consumer_queue = self.consumer_queues[consumer_tag]
yield from consumer_queue.put(None)
logger.info("consume cancelled received")

@asyncio.coroutine
Expand All @@ -674,6 +682,7 @@ def basic_cancel(self, consumer_tag, no_wait=False, timeout=None):
request = amqp_frame.AmqpEncoder()
request.write_shortstr(consumer_tag)
request.write_bits(no_wait)

return (yield from self._write_frame_awaiting_response(
'basic_cancel', frame, request, no_wait=no_wait, timeout=timeout))

Expand All @@ -684,6 +693,10 @@ def basic_cancel_ok(self, frame):
}
future = self._get_waiter('basic_cancel')
future.set_result(results)

consumer_queue = self.consumer_queues[results["consumer_tag"]]
yield from consumer_queue.put(None)

logger.debug("Cancel ok")

@asyncio.coroutine
Expand Down
32 changes: 32 additions & 0 deletions aioamqp/consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
import asyncio

import sys

PY35 = sys.version_info >= (3, 5)


class Consumer:
def __init__(self, queue: asyncio.Queue, consumer_tag):
self.queue = queue
self.tag = consumer_tag
self.message = None

if PY35:
async def __aiter__(self):
return self

async def __anext__(self):
return self.fetch_message()

@asyncio.coroutine
def fetch_message(self):

self.message = yield from self.queue.get()
if self.message:
return self.message
else:
raise StopIteration()

def get_message(self):
return self.message
71 changes: 43 additions & 28 deletions aioamqp/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,13 @@ class BasicCancelTestCase(testcase.RabbitTestCase, unittest.TestCase):
@testing.coroutine
def test_basic_cancel(self):

@asyncio.coroutine
def callback(channel, body, envelope, _properties):
pass

queue_name = 'queue_name'
exchange_name = 'exchange_name'
yield from self.channel.queue_declare(queue_name)
yield from self.channel.exchange_declare(exchange_name, type_name='direct')
yield from self.channel.queue_bind(queue_name, exchange_name, routing_key='')
result = yield from self.channel.basic_consume(callback, queue_name=queue_name)
result = yield from self.channel.basic_cancel(result['consumer_tag'])
consumer = yield from self.channel.basic_consume(queue_name=queue_name)
result = yield from self.channel.basic_cancel(consumer.tag)

result = yield from self.channel.publish("payload", exchange_name, routing_key='')

Expand Down Expand Up @@ -139,11 +135,16 @@ def test_ack_message(self):

qfuture = asyncio.Future(loop=self.loop)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
qfuture.set_result(envelope)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
qfuture.set_result(envelope)

asyncio.get_event_loop().create_task(consumer_task(consumer))

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
envelope = yield from qfuture

yield from qfuture
Expand All @@ -162,13 +163,18 @@ def test_basic_nack(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
yield from self.channel.basic_client_nack(
envelope.delivery_tag, multiple=True, requeue=False
)
qfuture.set_result(True)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
yield from self.channel.basic_client_nack(
envelope.delivery_tag, multiple=True, requeue=False
)
qfuture.set_result(True)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)

asyncio.get_event_loop().create_task(consumer_task(consumer))
yield from qfuture

@testing.coroutine
Expand All @@ -184,11 +190,14 @@ def test_basic_nack_norequeue(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=False)
qfuture.set_result(True)

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=False)
qfuture.set_result(True)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)
asyncio.get_event_loop().create_task(consumer_task(consumer))
yield from qfuture

@testing.coroutine
Expand All @@ -204,11 +213,14 @@ def test_basic_nack_requeue(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=True)
qfuture.set_result(True)

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=True)
qfuture.set_result(True)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)
asyncio.get_event_loop().create_task(consumer_task(consumer))
yield from qfuture


Expand All @@ -224,10 +236,13 @@ def test_basic_reject(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
qfuture.set_result(envelope)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
qfuture.set_result(envelope)

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
consumer = yield from self.channel.basic_consume(queue_name=queue_name)
asyncio.get_event_loop().create_task(consumer_task(consumer))
envelope = yield from qfuture

yield from self.channel.basic_reject(envelope.delivery_tag)
2 changes: 1 addition & 1 deletion aioamqp/tests/test_close.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ def test_cannot_consume_after_close(self):
yield from self.channel.queue_declare("q")
yield from channel.close()
with self.assertRaises(exceptions.ChannelClosed):
yield from channel.basic_consume(self.callback)
consumer = yield from channel.basic_consume()
48 changes: 31 additions & 17 deletions aioamqp/tests/test_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ def setUp(self):
self.consume_future = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def callback(self, channel, body, envelope, properties):
self.consume_future.set_result((body, envelope, properties))
def consumer_task(self, consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, properties = consumer.get_message()
self.consume_future.set_result((body, envelope, properties))

@asyncio.coroutine
def get_callback_result(self):
Expand Down Expand Up @@ -74,7 +76,8 @@ def test_consume(self):

yield from asyncio.sleep(2, loop=self.loop)
# start consume
yield from channel.basic_consume(self.callback, queue_name="q")
consumer = yield from channel.basic_consume(queue_name="q")
asyncio.get_event_loop().create_task(self.consumer_task(consumer))
# required ?
yield from asyncio.sleep(2, loop=self.loop)

Expand Down Expand Up @@ -107,7 +110,8 @@ def test_big_consume(self):
self.assertEqual(1, queues["q"]['messages'])

# start consume
yield from channel.basic_consume(self.callback, queue_name="q")
consumer = yield from channel.basic_consume(queue_name="q")
asyncio.get_event_loop().create_task(self.consumer_task(consumer))

yield from asyncio.sleep(1, loop=self.loop)

Expand All @@ -133,20 +137,27 @@ def test_consume_multiple_queues(self):
q1_future = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def q1_callback(channel, body, envelope, properties):
q1_future.set_result((body, envelope, properties))
def consumer_task1(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, properties = consumer.get_message()
q1_future.set_result((body, envelope, properties))

q2_future = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def q2_callback(channel, body, envelope, properties):
q2_future.set_result((body, envelope, properties))
def consumer_task2(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, properties = consumer.get_message()
q2_future.set_result((body, envelope, properties))

# start consumers
result = yield from channel.basic_consume(q1_callback, queue_name="q1")
ctag_q1 = result['consumer_tag']
result = yield from channel.basic_consume(q2_callback, queue_name="q2")
ctag_q2 = result['consumer_tag']
consumer1 = yield from channel.basic_consume(queue_name="q1")
asyncio.get_event_loop().create_task(consumer_task1(consumer1))

ctag_q1 = consumer1.tag
consumer2 = yield from channel.basic_consume(queue_name="q2")
asyncio.get_event_loop().create_task(consumer_task2(consumer2))
ctag_q2 = consumer2.tag

# put message in q1
yield from channel.publish("coucou1", "e", "q1")
Expand All @@ -171,10 +182,10 @@ def q2_callback(channel, body, envelope, properties):
def test_duplicate_consumer_tag(self):
yield from self.channel.queue_declare("q1", exclusive=True, no_wait=False)
yield from self.channel.queue_declare("q2", exclusive=True, no_wait=False)
yield from self.channel.basic_consume(self.callback, queue_name="q1", consumer_tag='tag')
consumer = yield from self.channel.basic_consume(queue_name="q1", consumer_tag='tag')

with self.assertRaises(exceptions.ChannelClosed) as cm:
yield from self.channel.basic_consume(self.callback, queue_name="q2", consumer_tag='tag')
consumer = yield from self.channel.basic_consume(queue_name="q2", consumer_tag='tag')

self.assertEqual(cm.exception.code, 530)

Expand All @@ -198,8 +209,11 @@ def test_consume_callaback_synced(self):
sync_future = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def callback(channel, body, envelope, properties):
self.assertTrue(sync_future.done())
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, properties = consumer.get_message()
self.assertTrue(sync_future.done())

yield from channel.basic_consume(callback, queue_name="q")
consumer = yield from channel.basic_consume(queue_name="q")
sync_future.set_result(True)
asyncio.get_event_loop().create_task(self.consumer_task(consumer))
Loading

0 comments on commit 6ac5bd2

Please sign in to comment.