diff --git a/aioamqp/channel.py b/aioamqp/channel.py index d3f249d..feee173 100644 --- a/aioamqp/channel.py +++ b/aioamqp/channel.py @@ -11,6 +11,7 @@ from . import constants as amqp_constants from . import frame as amqp_frame from . import exceptions +from .compat import iscoroutinepartial from .envelope import Envelope @@ -579,7 +580,7 @@ def basic_consume(self, callback, queue_name='', consumer_tag='', no_local=False the callback will be called each time we're receiving a message. Args: - callback: coroutine, the called callback + callback: callable, the called callback queue_name: str, the queue to receive message from consumer_tag: str, optional consumer tag no_local: bool, if set the server will not send messages @@ -656,7 +657,11 @@ def basic_deliver(self, frame): yield from event.wait() del self._ctag_events[consumer_tag] - yield from callback(self, body, envelope, properties) + if iscoroutinepartial(callback): + yield from callback(self, body, envelope, properties) + else: + self._loop.call_soon(callback, self, body, envelope, properties) + @asyncio.coroutine def server_basic_cancel(self, frame): diff --git a/aioamqp/compat.py b/aioamqp/compat.py index 64f7359..5fc31df 100644 --- a/aioamqp/compat.py +++ b/aioamqp/compat.py @@ -9,3 +9,17 @@ from asyncio import ensure_future except ImportError: ensure_future = asyncio.async + + +def iscoroutinepartial(fn): + # http://bugs.python.org/issue23519 + + while True: + parent = fn + + fn = getattr(parent, 'func', None) + + if fn is None: + break + + return asyncio.iscoroutinefunction(parent) diff --git a/aioamqp/tests/test_consume.py b/aioamqp/tests/test_consume.py index a820ac9..41da6f0 100644 --- a/aioamqp/tests/test_consume.py +++ b/aioamqp/tests/test_consume.py @@ -21,6 +21,9 @@ def setUp(self): def callback(self, channel, body, envelope, properties): self.consume_future.set_result((body, envelope, properties)) + def callback_sync(self, channel, body, envelope, properties): + self.consume_future.set_result((body, envelope, properties)) + @asyncio.coroutine def get_callback_result(self): yield from self.consume_future @@ -86,6 +89,33 @@ def test_consume(self): self.assertEqual(b"coucou", body) self.assertIsInstance(properties, Properties) + @testing.coroutine + def test_consume_not_coroutine(self): + # declare + yield from self.channel.queue_declare("q", exclusive=True, no_wait=False) + yield from self.channel.exchange_declare("e", "fanout") + yield from self.channel.queue_bind("q", "e", routing_key='') + + # get a different channel + channel = yield from self.create_channel() + + # publish + yield from channel.publish("coucou", "e", routing_key='',) + + yield from asyncio.sleep(2, loop=self.loop) + # start consume + yield from channel.basic_consume(self.callback_sync, queue_name="q") + # required ? + yield from asyncio.sleep(2, loop=self.loop) + + self.assertTrue(self.consume_future.done()) + # get one + body, envelope, properties = yield from self.get_callback_result() + self.assertIsNotNone(envelope.consumer_tag) + self.assertIsNotNone(envelope.delivery_tag) + self.assertEqual(b"coucou", body) + self.assertIsInstance(properties, Properties) + @testing.coroutine def test_big_consume(self): # declare