diff --git a/.gitignore b/.gitignore index 7cbf5f7..c317951 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ *.so # Distribution / packaging +virtualenv .Python env/ venv/ @@ -34,7 +35,7 @@ pip-delete-this-directory.txt .cache nosetests.xml coverage.xml - +cover # Translations *.mo @@ -55,3 +56,4 @@ docs/_build/ # editor stuffs *.swp +.idea \ No newline at end of file diff --git a/AUTHORS.rst b/AUTHORS.rst index 5e040ea..385aa70 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -15,3 +15,4 @@ AUTHORS are (and/or have been):: * Igor `mastak` * Hans Lellelid * `iceboy-sjtu` + * Sergio Medina Toledo diff --git a/aioamqp/channel.py b/aioamqp/channel.py index 759452f..4a8b14c 100644 --- a/aioamqp/channel.py +++ b/aioamqp/channel.py @@ -36,6 +36,10 @@ def __init__(self, protocol, channel_id): self._futures = {} self._ctag_events = {} + def __del__(self): + for queue in self.consumer_queues.values(): + asyncio.ensure_future(queue.put(StopIteration()), loop=self._loop) + def _set_waiter(self, rpc_name): if rpc_name in self._futures: raise exceptions.SynchronizationError("Waiter already exists") @@ -70,7 +74,7 @@ def connection_closed(self, server_code=None, server_reason=None, exception=None self.protocol.release_channel_id(self.channel_id) for queue in self.consumer_queues.values(): - yield from queue.put(None) + asyncio.ensure_future(queue.put(StopIteration()), loop=self._loop) self.close_event.set() @asyncio.coroutine @@ -169,7 +173,7 @@ def close_ok(self, frame): logger.info("Channel closed") for queue in self.consumer_queues.values(): - yield from queue.put(None) + yield from queue.put(StopIteration()) self.protocol.release_channel_id(self.channel_id) @@ -670,7 +674,7 @@ def server_basic_cancel(self, frame): consumer_tag = frame.arguments['consumer_tag'] self.cancelled_consumers.add(consumer_tag) consumer_queue = self.consumer_queues[consumer_tag] - yield from consumer_queue.put(None) + yield from consumer_queue.put(StopIteration()) logger.info("consume cancelled received") @asyncio.coroutine @@ -695,7 +699,7 @@ def basic_cancel_ok(self, frame): future.set_result(results) consumer_queue = self.consumer_queues[results["consumer_tag"]] - yield from consumer_queue.put(None) + yield from consumer_queue.put(StopIteration()) logger.debug("Cancel ok") diff --git a/aioamqp/consumer.py b/aioamqp/consumer.py index 27c9715..16ff899 100644 --- a/aioamqp/consumer.py +++ b/aioamqp/consumer.py @@ -3,30 +3,50 @@ import sys +import logging + PY35 = sys.version_info >= (3, 5) +logger = logging.getLogger(__name__) + + +class ConsumerStoped(Exception): + pass class Consumer: def __init__(self, queue: asyncio.Queue, consumer_tag): - self.queue = queue + self._queue = queue self.tag = consumer_tag self.message = None + self._stoped = False if PY35: async def __aiter__(self): return self async def __anext__(self): - return self.fetch_message() + if not self._stoped: + self.message = await self._queue.get() + if isinstance(self.message, StopIteration): + self._stoped = True + raise StopAsyncIteration() + else: + return self.message + raise StopAsyncIteration() @asyncio.coroutine def fetch_message(self): - - self.message = yield from self.queue.get() - if self.message: - return self.message + if not self._stoped: + self.message = yield from self._queue.get() + if isinstance(self.message, StopIteration): + self._stoped = True + return False + else: + return True else: - raise StopIteration() + return False def get_message(self): - return self.message \ No newline at end of file + if self._stoped: + raise ConsumerStoped() + return self.message diff --git a/aioamqp/protocol.py b/aioamqp/protocol.py index 9c916c4..d1b61c9 100644 --- a/aioamqp/protocol.py +++ b/aioamqp/protocol.py @@ -12,12 +12,10 @@ from . import version from .compat import ensure_future - logger = logging.getLogger(__name__) class _StreamWriter(asyncio.StreamWriter): - def write(self, data): ret = super().write(data) self._protocol._heartbeat_timer_send_reset() @@ -135,7 +133,7 @@ def close_ok(self, frame): @asyncio.coroutine def start_connection(self, host, port, login, password, virtualhost, ssl=False, - login_method='AMQPLAIN', insist=False): + login_method='AMQPLAIN', insist=False): """Initiate a connection at the protocol level We send `PROTOCOL_HEADER' """ @@ -194,8 +192,8 @@ def start_connection(self, host, port, login, password, virtualhost, ssl=False, frame = yield from self.get_frame() yield from self.dispatch_frame(frame) if (frame.frame_type == amqp_constants.TYPE_METHOD and - frame.class_id == amqp_constants.CLASS_CONNECTION and - frame.method_id == amqp_constants.CONNECTION_CLOSE): + frame.class_id == amqp_constants.CLASS_CONNECTION and + frame.method_id == amqp_constants.CONNECTION_CLOSE): raise exceptions.AmqpClosedConnection() # for now, we read server's responses asynchronously @@ -376,10 +374,9 @@ def server_close(self, frame): method_id = response.read_short() self.stop() logger.warning("Server closed connection: %s, code=%s, class_id=%s, method_id=%s", - reply_text, reply_code, class_id, method_id) + reply_text, reply_code, class_id, method_id) self._close_channels(reply_code, reply_text) - @asyncio.coroutine def tune(self, frame): decoder = amqp_frame.AmqpDecoder(frame.payload) diff --git a/aioamqp/tests/test_channel.py b/aioamqp/tests/test_channel.py index 06e91df..06449c8 100644 --- a/aioamqp/tests/test_channel.py +++ b/aioamqp/tests/test_channel.py @@ -5,14 +5,16 @@ import os import unittest +import asyncio + from . import testcase from . import testing from .. import exceptions IMPLEMENT_CHANNEL_FLOW = os.environ.get('IMPLEMENT_CHANNEL_FLOW', False) -class ChannelTestCase(testcase.RabbitTestCase, unittest.TestCase): +class ChannelTestCase(testcase.RabbitTestCase, unittest.TestCase): _multiprocess_can_split_ = True @testing.coroutine @@ -83,9 +85,37 @@ def test_channel_active_inactive_flow(self): result = yield from channel.flow(active=False) self.assertFalse(result['active']) + @testing.coroutine + def test_channel_cancel_stops_consumer(self): + # declare + yield from self.channel.queue_declare("q", exclusive=True, no_wait=False) + yield from self.channel.exchange_declare("e", "fanout") + yield from self.channel.queue_bind("q", "e", routing_key='') -class ChannelIdTestCase(testcase.RabbitTestCase, unittest.TestCase): + # get a different channel + channel = yield from self.create_channel() + + # publish + yield from channel.publish("coucou", "e", routing_key='', ) + + consumer_stoped = asyncio.Future() + @asyncio.coroutine + def consumer_task(consumer): + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + + consumer_stoped.set_result(True) + + consumer = yield from channel.basic_consume(queue_name="q") + asyncio.get_event_loop().create_task(consumer_task(consumer)) + + yield from channel.basic_cancel(consumer.tag) + + assert (yield from consumer_stoped) + + +class ChannelIdTestCase(testcase.RabbitTestCase, unittest.TestCase): @testing.coroutine def test_channel_id_release_close(self): channels_count_start = self.amqp.channels_ids_count diff --git a/aioamqp/version.py b/aioamqp/version.py index 4e431bf..a322420 100644 --- a/aioamqp/version.py +++ b/aioamqp/version.py @@ -1,2 +1,2 @@ -__version__ = '0.8.2' +__version__ = '0.9.0' __packagename__ = 'aioamqp' diff --git a/docs/api.rst b/docs/api.rst index 46bb428..8f6c727 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -8,10 +8,11 @@ API Basics ------ -There are two principal objects when using aioamqp: +There are three principal objects when using aioamqp: * The protocol object, used to begin a connection to aioamqp, * The channel object, used when creating a new channel to effectively use an AMQP channel. + * The consumer object, used for consuming messages from a queue. Starting a connection @@ -141,20 +142,21 @@ When consuming message, you connect to the same queue you previously created:: import asyncio import aioamqp - @asyncio.coroutine - def callback(body, envelope, properties): - print(body) - channel = yield from protocol.channel() - yield from channel.basic_consume(callback, queue_name="my_queue") + consumer = yield from channel.basic_consume(queue_name="my_queue") + + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + print(body) -The ``basic_consume`` method tells the server to send us the messages, and will call ``callback`` with amqp response arguments. +The ``basic_consume`` method tells the server to send us the messages, and will add the amqp response arguments to the consumer queue +from where you can get them and process using an async iterator or a while using the ``fetch_message`` and ``get_message`` methods of the consumer class. -The ``consumer_tag`` is the id of your consumer, and the ``delivery_tag`` is the tag used if you want to acknowledge the message. +The ``consumer.tag`` or ``envelope.consumer_tag`` is the id of your consumer, and the ``envelope.delivery_tag`` is the tag used if you want to acknowledge the message. -In the callback: +In the while: -* the first ``body`` parameter is the message +* the ``body`` returned by ``get_message`` is the message * the ``envelope`` is an instance of envelope.Envelope class which encapsulate a group of amqp parameter such as:: consumer_tag diff --git a/docs/changelog.rst b/docs/changelog.rst index e88376b..b9ac7bf 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,11 @@ Changelog ========= +Aioamqp 0.9.0 +------------- + + * Changed consumer API from callbacks to coroutines solving the problem of publish inside of a consumer callback + Aioamqp 0.8.2 ------------- diff --git a/docs/examples/hello_world.rst b/docs/examples/hello_world.rst index 2eafe57..8337011 100644 --- a/docs/examples/hello_world.rst +++ b/docs/examples/hello_world.rst @@ -63,13 +63,21 @@ We have to ensure the queue is created. Queue declaration is indempotant. yield from channel.queue_declare(queue_name='hello') -To consume a message, the library calls a callback (which **MUST** be a coroutine): +To consume a message is used the Consumer object using an async iter if is python 3.5 or while with ``fetch_message`` and ``get_message`` if is python < 3.5: .. code-block:: python - @asyncio.coroutine - def callback(channel, body, envelope, properties): + consumer = yield from channel.basic_consume(queue_name='hello', no_ack=True) + + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() print(body) - yield from channel.basic_consume(callback, queue_name='hello', no_ack=True) +For python 3.5 : + +.. code-block:: python + consumer = await channel.basic_consume(queue_name='hello', no_ack=True) + + async for channel, body, envelope, properties in consumer: + print(body) diff --git a/docs/examples/rpc.rst b/docs/examples/rpc.rst index 13d4dd1..47ce952 100644 --- a/docs/examples/rpc.rst +++ b/docs/examples/rpc.rst @@ -40,12 +40,12 @@ Note: the client use a `waiter` (an asyncio.Event) which will be set when receiv Server ------ -When unqueing a message, the server will publish a response directly in the callback. The `correlation_id` is used to let the client know it's a response from this request. +When unqueing a message, the server will enqueue a response directly in the consumer loop. The `correlation_id` is used to let the client know it's a response from this request. .. code-block:: python - @asyncio.coroutine - def on_request(channel, body, envelope, properties): + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() n = int(body) print(" [.] fib(%s)" % n) diff --git a/docs/examples/work_queue.rst b/docs/examples/work_queue.rst index 0b2bb34..05ee3af 100644 --- a/docs/examples/work_queue.rst +++ b/docs/examples/work_queue.rst @@ -41,14 +41,14 @@ Then, the worker configure the `QOS`: it specifies how the worker unqueues messa yield from channel.basic_qos(prefetch_count=1, prefetch_size=0, connection_global=False) -Finaly we have to create a callback that will `ack` the message to mark it as `processed`. -Note: the code in the callback calls `asyncio.sleep` to simulate an asyncio compatible task that takes time. +Finaly we have to create a messages processor that will `ack` the message to mark it as `processed`. +Note: the code in the while calls `asyncio.sleep` to simulate an asyncio compatible task that takes time. You probably want to block the eventloop to simulate a CPU intensive task using `time.sleep`. .. code-block:: python - @asyncio.coroutine - def callback(channel, body, envelope, properties): + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() print(" [x] Received %r" % body) yield from asyncio.sleep(body.count(b'.')) print(" [x] Done") diff --git a/examples/receive.py b/examples/receive.py index 062b7bc..6920b5c 100644 --- a/examples/receive.py +++ b/examples/receive.py @@ -7,10 +7,6 @@ import aioamqp -@asyncio.coroutine -def callback(channel, body, envelope, properties): - print(" [x] Received %r" % body) - @asyncio.coroutine def receive(): transport, protocol = yield from aioamqp.connect() @@ -18,8 +14,11 @@ def receive(): yield from channel.queue_declare(queue_name='hello') - yield from channel.basic_consume(callback, queue_name='hello') + consumer = yield from channel.basic_consume(queue_name='hello') + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + print(" [x] Received %r" % body) event_loop = asyncio.get_event_loop() event_loop.run_until_complete(receive()) diff --git a/examples/receive_log.py b/examples/receive_log.py index 9292fd8..f2523f8 100644 --- a/examples/receive_log.py +++ b/examples/receive_log.py @@ -11,12 +11,6 @@ import random - -@asyncio.coroutine -def callback(channel, body, envelope, properties): - print(" [x] %r" % body) - - @asyncio.coroutine def receive_log(): try: @@ -39,7 +33,11 @@ def receive_log(): print(' [*] Waiting for logs. To exit press CTRL+C') - yield from channel.basic_consume(callback, queue_name=queue_name, no_ack=True) + consumer = yield from channel.basic_consume(queue_name=queue_name, no_ack=True) + + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + print(" [x] %r" % body) event_loop = asyncio.get_event_loop() event_loop.run_until_complete(receive_log()) diff --git a/examples/receive_log_direct.py b/examples/receive_log_direct.py index b272d8b..0d77014 100644 --- a/examples/receive_log_direct.py +++ b/examples/receive_log_direct.py @@ -14,11 +14,7 @@ @asyncio.coroutine -def callback(channel, body, envelope, properties): - print("consumer {} recved {} ({})".format(envelope.consumer_tag, body, envelope.delivery_tag)) - -@asyncio.coroutine -def receive_log(waiter): +def receive_log(): try: transport, protocol = yield from aioamqp.connect('localhost', 5672) except aioamqp.AmqpClosedConnection: @@ -48,16 +44,16 @@ def receive_log(waiter): print(' [*] Waiting for logs. To exit press CTRL+C') - yield from asyncio.wait_for(channel.basic_consume(callback, queue_name=queue_name), timeout=10) - yield from waiter.wait() + consumer = yield from channel.basic_consume(queue_name=queue_name) + + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + print("consumer {} recved {} ({})".format(envelope.consumer_tag, body, envelope.delivery_tag)) yield from protocol.close() transport.close() + loop = asyncio.get_event_loop() +loop.run_until_complete(receive_log()) -try: - waiter = asyncio.Event() - loop.run_until_complete(receive_log(waiter)) -except KeyboardInterrupt: - waiter.set() diff --git a/examples/receive_log_topic.py b/examples/receive_log_topic.py index 079be2b..a859564 100644 --- a/examples/receive_log_topic.py +++ b/examples/receive_log_topic.py @@ -13,11 +13,6 @@ import sys -@asyncio.coroutine -def callback(channel, body, envelope, properties): - print("consumer {} received {} ({})".format(envelope.consumer_tag, body, envelope.delivery_tag)) - - @asyncio.coroutine def receive_log(): try: @@ -48,7 +43,11 @@ def receive_log(): print(' [*] Waiting for logs. To exit press CTRL+C') - yield from channel.basic_consume(callback, queue_name=queue_name) + consumer = yield from channel.basic_consume(queue_name=queue_name) + + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + print("consumer {} received {} ({})".format(envelope.consumer_tag, body, envelope.delivery_tag)) event_loop = asyncio.get_event_loop() event_loop.run_until_complete(receive_log()) diff --git a/examples/rpc_client.py b/examples/rpc_client.py index 9883c3e..93bdbc3 100644 --- a/examples/rpc_client.py +++ b/examples/rpc_client.py @@ -17,6 +17,7 @@ def __init__(self): self.protocol = None self.channel = None self.callback_queue = None + self.consumer_task = None self.waiter = asyncio.Event() @asyncio.coroutine @@ -28,18 +29,21 @@ def connect(self): result = yield from self.channel.queue_declare(queue_name='', exclusive=True) self.callback_queue = result['queue'] - yield from self.channel.basic_consume( + consumer = yield from self.channel.basic_consume( self.on_response, no_ack=True, queue_name=self.callback_queue, ) + self.consumer_task = asyncio.get_event_loop().create_task(self.on_response(consumer)) @asyncio.coroutine - def on_response(self, channel, body, envelope, properties): - if self.corr_id == properties.correlation_id: - self.response = body + def on_response(self, consumer): + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + if self.corr_id == properties.correlation_id: + self.response = body - self.waiter.set() + self.waiter.set() @asyncio.coroutine def call(self, n): diff --git a/examples/rpc_server.py b/examples/rpc_server.py index 42309b2..c2f13f7 100644 --- a/examples/rpc_server.py +++ b/examples/rpc_server.py @@ -15,25 +15,6 @@ def fib(n): return fib(n-1) + fib(n-2) -@asyncio.coroutine -def on_request(channel, body, envelope, properties): - n = int(body) - - print(" [.] fib(%s)" % n) - response = fib(n) - - yield from channel.basic_publish( - payload=str(response), - exchange_name='', - routing_key=properties.reply_to, - properties={ - 'correlation_id': properties.correlation_id, - }, - ) - - yield from channel.basic_client_ack(delivery_tag=envelope.delivery_tag) - - @asyncio.coroutine def rpc_server(): @@ -43,9 +24,26 @@ def rpc_server(): yield from channel.queue_declare(queue_name='rpc_queue') yield from channel.basic_qos(prefetch_count=1, prefetch_size=0, connection_global=False) - yield from channel.basic_consume(on_request, queue_name='rpc_queue') + consumer = yield from channel.basic_consume(queue_name='rpc_queue') print(" [x] Awaiting RPC requests") + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + n = int(body) + + print(" [.] fib(%s)" % n) + response = fib(n) + + yield from channel.basic_publish( + payload=str(response), + exchange_name='', + routing_key=properties.reply_to, + properties={ + 'correlation_id': properties.correlation_id, + }, + ) + + yield from channel.basic_client_ack(delivery_tag=envelope.delivery_tag) event_loop = asyncio.get_event_loop() event_loop.run_until_complete(rpc_server()) diff --git a/examples/worker.py b/examples/worker.py index bfe530b..b6e249f 100644 --- a/examples/worker.py +++ b/examples/worker.py @@ -8,13 +8,6 @@ import sys -@asyncio.coroutine -def callback(channel, body, envelope, properties): - print(" [x] Received %r" % body) - yield from asyncio.sleep(body.count(b'.')) - print(" [x] Done") - yield from channel.basic_client_ack(delivery_tag=envelope.delivery_tag) - @asyncio.coroutine def worker(): @@ -24,17 +17,20 @@ def worker(): print("closed connections") return - channel = yield from protocol.channel() yield from channel.queue(queue_name='task_queue', durable=True) yield from channel.basic_qos(prefetch_count=1, prefetch_size=0, connection_global=False) - yield from channel.basic_consume(callback, queue_name='task_queue') + consumer = yield from channel.basic_consume(queue_name='task_queue') + + while (yield from consumer.fetch_message()): + channel, body, envelope, properties = consumer.get_message() + print(" [x] Received %r" % body) + yield from asyncio.sleep(body.count(b'.')) + print(" [x] Done") + yield from channel.basic_client_ack(delivery_tag=envelope.delivery_tag) event_loop = asyncio.get_event_loop() event_loop.run_until_complete(worker()) event_loop.run_forever() - - -