Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] First implementation of a new consumer API #118

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__/
*.so

# Distribution / packaging
virtualenv
.Python
env/
venv/
Expand Down Expand Up @@ -34,7 +35,7 @@ pip-delete-this-directory.txt
.cache
nosetests.xml
coverage.xml

cover
# Translations
*.mo

Expand All @@ -55,3 +56,4 @@ docs/_build/

# editor stuffs
*.swp
.idea
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ AUTHORS are (and/or have been)::
* Igor `mastak`
* Hans Lellelid
* `iceboy-sjtu`
* Sergio Medina Toledo
39 changes: 28 additions & 11 deletions aioamqp/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import io
from itertools import count

from aioamqp.consumer import Consumer

from . import constants as amqp_constants
from . import frame as amqp_frame
from . import exceptions
Expand All @@ -24,7 +26,6 @@ def __init__(self, protocol, channel_id):
self.protocol = protocol
self.channel_id = channel_id
self.consumer_queues = {}
self.consumer_callbacks = {}
self.response_future = None
self.close_event = asyncio.Event(loop=self._loop)
self.cancelled_consumers = set()
Expand All @@ -35,6 +36,10 @@ def __init__(self, protocol, channel_id):
self._futures = {}
self._ctag_events = {}

def __del__(self):
for queue in self.consumer_queues.values():
asyncio.ensure_future(queue.put(StopIteration()), loop=self._loop)

def _set_waiter(self, rpc_name):
if rpc_name in self._futures:
raise exceptions.SynchronizationError("Waiter already exists")
Expand Down Expand Up @@ -67,6 +72,9 @@ def connection_closed(self, server_code=None, server_reason=None, exception=None
future.set_exception(exception)

self.protocol.release_channel_id(self.channel_id)

for queue in self.consumer_queues.values():
asyncio.ensure_future(queue.put(StopIteration()), loop=self._loop)
self.close_event.set()

@asyncio.coroutine
Expand Down Expand Up @@ -163,6 +171,10 @@ def close(self, reply_code=0, reply_text="Normal Shutdown", timeout=None):
def close_ok(self, frame):
self._get_waiter('close').set_result(True)
logger.info("Channel closed")

for queue in self.consumer_queues.values():
yield from queue.put(StopIteration())

self.protocol.release_channel_id(self.channel_id)

@asyncio.coroutine
Expand Down Expand Up @@ -573,13 +585,12 @@ def basic_server_nack(self, frame, delivery_tag=None):
fut.set_exception(exceptions.PublishFailed(delivery_tag))

@asyncio.coroutine
def basic_consume(self, callback, queue_name='', consumer_tag='', no_local=False, no_ack=False,
def basic_consume(self, queue_name='', consumer_tag='', no_local=False, no_ack=False,
exclusive=False, no_wait=False, arguments=None, timeout=None):
"""Starts the consumption of message into a queue.
the callback will be called each time we're receiving a message.

Args:
callback: coroutine, the called callback
queue_name: str, the queue to receive message from
consumer_tag: str, optional consumer tag
no_local: bool, if set the server will not send messages
Expand Down Expand Up @@ -609,16 +620,15 @@ def basic_consume(self, callback, queue_name='', consumer_tag='', no_local=False
request.write_bits(no_local, no_ack, exclusive, no_wait)
request.write_table(arguments)

self.consumer_callbacks[consumer_tag] = callback
self.last_consumer_tag = consumer_tag
self.consumer_queues[consumer_tag] = asyncio.Queue()

consumer = Consumer(self.consumer_queues[consumer_tag], consumer_tag)

return_value = yield from self._write_frame_awaiting_response(
'basic_consume', frame, request, no_wait, timeout=timeout)
if no_wait:
return_value = {'consumer_tag': consumer_tag}
else:
if not no_wait:
self._ctag_events[consumer_tag].set()
return return_value
return consumer

@asyncio.coroutine
def basic_consume_ok(self, frame):
Expand Down Expand Up @@ -649,20 +659,22 @@ def basic_deliver(self, frame):
envelope = Envelope(consumer_tag, delivery_tag, exchange_name, routing_key, is_redeliver)
properties = content_header_frame.properties

callback = self.consumer_callbacks[consumer_tag]
consumer_queue = self.consumer_queues[consumer_tag]

event = self._ctag_events.get(consumer_tag)
if event:
yield from event.wait()
del self._ctag_events[consumer_tag]

yield from callback(self, body, envelope, properties)
yield from consumer_queue.put((self, body, envelope, properties))

@asyncio.coroutine
def server_basic_cancel(self, frame):
"""From the server, means the server won't send anymore messages to this consumer."""
consumer_tag = frame.arguments['consumer_tag']
self.cancelled_consumers.add(consumer_tag)
consumer_queue = self.consumer_queues[consumer_tag]
yield from consumer_queue.put(StopIteration())
logger.info("consume cancelled received")

@asyncio.coroutine
Expand All @@ -674,6 +686,7 @@ def basic_cancel(self, consumer_tag, no_wait=False, timeout=None):
request = amqp_frame.AmqpEncoder()
request.write_shortstr(consumer_tag)
request.write_bits(no_wait)

return (yield from self._write_frame_awaiting_response(
'basic_cancel', frame, request, no_wait=no_wait, timeout=timeout))

Expand All @@ -684,6 +697,10 @@ def basic_cancel_ok(self, frame):
}
future = self._get_waiter('basic_cancel')
future.set_result(results)

consumer_queue = self.consumer_queues[results["consumer_tag"]]
yield from consumer_queue.put(StopIteration())

logger.debug("Cancel ok")

@asyncio.coroutine
Expand Down
54 changes: 54 additions & 0 deletions aioamqp/consumer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
import asyncio

import sys

import logging

PY35 = sys.version_info >= (3, 5)
logger = logging.getLogger(__name__)


class ConsumerStoped(Exception):
pass


class Consumer:
def __init__(self, queue: asyncio.Queue, consumer_tag):
self._queue = queue
self.tag = consumer_tag
self.message = None
self._stoped = False

if PY35:
@asyncio.coroutine
def __aiter__(self):
return self

@asyncio.coroutine
def __anext__(self):
if not self._stoped:
self.message = yield from self._queue.get()
if isinstance(self.message, StopIteration):
self._stoped = True
raise StopAsyncIteration()
else:
return self.message
raise StopAsyncIteration()

@asyncio.coroutine
def fetch_message(self):
if not self._stoped:
self.message = yield from self._queue.get()
if isinstance(self.message, StopIteration):
self._stoped = True
return False
else:
return True
else:
return False

def get_message(self):
if self._stoped:
raise ConsumerStoped()
return self.message
11 changes: 4 additions & 7 deletions aioamqp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
from . import version
from .compat import ensure_future


logger = logging.getLogger(__name__)


class _StreamWriter(asyncio.StreamWriter):

def write(self, data):
ret = super().write(data)
self._protocol._heartbeat_timer_send_reset()
Expand Down Expand Up @@ -135,7 +133,7 @@ def close_ok(self, frame):

@asyncio.coroutine
def start_connection(self, host, port, login, password, virtualhost, ssl=False,
login_method='AMQPLAIN', insist=False):
login_method='AMQPLAIN', insist=False):
"""Initiate a connection at the protocol level
We send `PROTOCOL_HEADER'
"""
Expand Down Expand Up @@ -194,8 +192,8 @@ def start_connection(self, host, port, login, password, virtualhost, ssl=False,
frame = yield from self.get_frame()
yield from self.dispatch_frame(frame)
if (frame.frame_type == amqp_constants.TYPE_METHOD and
frame.class_id == amqp_constants.CLASS_CONNECTION and
frame.method_id == amqp_constants.CONNECTION_CLOSE):
frame.class_id == amqp_constants.CLASS_CONNECTION and
frame.method_id == amqp_constants.CONNECTION_CLOSE):
raise exceptions.AmqpClosedConnection()

# for now, we read server's responses asynchronously
Expand Down Expand Up @@ -376,10 +374,9 @@ def server_close(self, frame):
method_id = response.read_short()
self.stop()
logger.warning("Server closed connection: %s, code=%s, class_id=%s, method_id=%s",
reply_text, reply_code, class_id, method_id)
reply_text, reply_code, class_id, method_id)
self._close_channels(reply_code, reply_text)


@asyncio.coroutine
def tune(self, frame):
decoder = amqp_frame.AmqpDecoder(frame.payload)
Expand Down
71 changes: 43 additions & 28 deletions aioamqp/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,13 @@ class BasicCancelTestCase(testcase.RabbitTestCase, unittest.TestCase):
@testing.coroutine
def test_basic_cancel(self):

@asyncio.coroutine
def callback(channel, body, envelope, _properties):
pass

queue_name = 'queue_name'
exchange_name = 'exchange_name'
yield from self.channel.queue_declare(queue_name)
yield from self.channel.exchange_declare(exchange_name, type_name='direct')
yield from self.channel.queue_bind(queue_name, exchange_name, routing_key='')
result = yield from self.channel.basic_consume(callback, queue_name=queue_name)
result = yield from self.channel.basic_cancel(result['consumer_tag'])
consumer = yield from self.channel.basic_consume(queue_name=queue_name)
result = yield from self.channel.basic_cancel(consumer.tag)

result = yield from self.channel.publish("payload", exchange_name, routing_key='')

Expand Down Expand Up @@ -139,11 +135,16 @@ def test_ack_message(self):

qfuture = asyncio.Future(loop=self.loop)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
qfuture.set_result(envelope)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
qfuture.set_result(envelope)

asyncio.get_event_loop().create_task(consumer_task(consumer))

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
envelope = yield from qfuture

yield from qfuture
Expand All @@ -162,13 +163,18 @@ def test_basic_nack(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
yield from self.channel.basic_client_nack(
envelope.delivery_tag, multiple=True, requeue=False
)
qfuture.set_result(True)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
yield from self.channel.basic_client_nack(
envelope.delivery_tag, multiple=True, requeue=False
)
qfuture.set_result(True)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)

asyncio.get_event_loop().create_task(consumer_task(consumer))
yield from qfuture

@testing.coroutine
Expand All @@ -184,11 +190,14 @@ def test_basic_nack_norequeue(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=False)
qfuture.set_result(True)

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=False)
qfuture.set_result(True)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)
asyncio.get_event_loop().create_task(consumer_task(consumer))
yield from qfuture

@testing.coroutine
Expand All @@ -204,11 +213,14 @@ def test_basic_nack_requeue(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=True)
qfuture.set_result(True)

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
yield from self.channel.basic_client_nack(envelope.delivery_tag, requeue=True)
qfuture.set_result(True)

consumer = yield from self.channel.basic_consume(queue_name=queue_name)
asyncio.get_event_loop().create_task(consumer_task(consumer))
yield from qfuture


Expand All @@ -224,10 +236,13 @@ def test_basic_reject(self):
qfuture = asyncio.Future(loop=self.loop)

@asyncio.coroutine
def qcallback(channel, body, envelope, _properties):
qfuture.set_result(envelope)
def consumer_task(consumer):
while (yield from consumer.fetch_message()):
channel, body, envelope, _properties = consumer.get_message()
qfuture.set_result(envelope)

yield from self.channel.basic_consume(qcallback, queue_name=queue_name)
consumer = yield from self.channel.basic_consume(queue_name=queue_name)
asyncio.get_event_loop().create_task(consumer_task(consumer))
envelope = yield from qfuture

yield from self.channel.basic_reject(envelope.delivery_tag)
Loading