diff --git a/docs/custom_service.md b/docs/custom_service.md
index c18f6ce5f..dec12cbc4 100644
--- a/docs/custom_service.md
+++ b/docs/custom_service.md
@@ -123,6 +123,18 @@ Here the ``` handle()``` method is our entry point that will be invoked by MMS,
 
 This entry point is engaged in two cases: (1) when MMS is asked to scale a model up, to increase the number of backend workers (it is done either via a ```PUT /models/{model_name}``` request or a ```POST /models``` request with `initial-workers` option or during MMS startup when you use `--models` option (```multi-model-server --start --models {model_name=model.mar}```), ie., you provide model(s) to load) or (2) when MMS gets a ```POST /predictions/{model_name}``` request. (1) is used to scale-up or scale-down workers for a model. (2) is used as a standard way to run inference against a model. (1) is also known as model load time, and that is where you would normally want to put code for model initialization. You can find out more about these and other MMS APIs in [MMS Management API](./management_api.md) and [MMS Inference API](./inference_api.md)
 
+
+### Returning custom error codes
+
+To return a custom error code back to the user, raise a `PredictionException` from the `mms.service` module with the desired message and error code.
+
+```python
+from mms.service import PredictionException
+def handler(data, context):
+    # Some unexpected error - returning error code 513
+    raise PredictionException("Some Prediction Error", 513)
+```
+
 ## Creating model archive with entry point
 
 MMS, identifies the entry point to the custom service, from the manifest file. Thus file creating the model archive, one needs to mention the entry point using the ```--handler``` option.
diff --git a/frontend/modelarchive/src/test/resources/models/custom-return-code/MAR-INF/MANIFEST.json b/frontend/modelarchive/src/test/resources/models/custom-return-code/MAR-INF/MANIFEST.json
new file mode 100644
index 000000000..e0d80eb59
--- /dev/null
+++ b/frontend/modelarchive/src/test/resources/models/custom-return-code/MAR-INF/MANIFEST.json
@@ -0,0 +1,18 @@
+{
+  "specificationVersion": "1.0",
+  "implementationVersion": "1.0",
+  "description": "noop v1.0",
+  "modelServerVersion": "1.0",
+  "license": "Apache 2.0",
+  "runtime": "python",
+  "model": {
+    "modelName": "pred-custom-return-code",
+    "description": "Tests for custom return code",
+    "modelVersion": "1.0",
+    "handler": "service:handle"
+  },
+  "publisher": {
+    "author": "MXNet SDK team",
+    "email": "noreply@amazon.com"
+  }
+}
diff --git a/frontend/modelarchive/src/test/resources/models/custom-return-code/service.py b/frontend/modelarchive/src/test/resources/models/custom-return-code/service.py
new file mode 100644
index 000000000..8ee9d1def
--- /dev/null
+++ b/frontend/modelarchive/src/test/resources/models/custom-return-code/service.py
@@ -0,0 +1,18 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+# http://www.apache.org/licenses/LICENSE-2.0
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from mms.service import PredictionException
+
+def handle(data, ctx):
+    # Data is not None for a prediction request
+    # Raise PredictionException to return a custom error code (599) to the client
+    if data is not None:
+        raise PredictionException("Some Prediction Error", 599)
+    return ["OK"]
diff --git a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
index 8aca550d6..92639696e 100644
--- a/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
+++ b/frontend/server/src/test/java/com/amazonaws/ml/mms/ModelServerTest.java
@@ -183,6 +183,7 @@ public void test()
         testLoggingUnload(channel, managementChannel);
         testLoadingMemoryError();
         testPredictionMemoryError();
+        testPredictionCustomErrorCode();
         testMetricManager();
         testErrorBatch();
 
@@ -1189,6 +1190,51 @@ private void testPredictionMemoryError() throws InterruptedException {
         Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
     }
 
+    private void testPredictionCustomErrorCode() throws InterruptedException {
+        // Load the model
+        Channel channel = connect(true);
+        Assert.assertNotNull(channel);
+        result = null;
+        latch = new CountDownLatch(1);
+        DefaultFullHttpRequest req =
+                new DefaultFullHttpRequest(
+                        HttpVersion.HTTP_1_1,
+                        HttpMethod.POST,
+                        "/models?url=custom-return-code&model_name=custom-return-code&runtime=python&initial_workers=1&synchronous=true");
+        channel.writeAndFlush(req);
+        latch.await();
+        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
+        channel.close();
+
+        // Test for prediction
+        channel = connect(false);
+        Assert.assertNotNull(channel);
+        result = null;
+        latch = new CountDownLatch(1);
+        req =
+                new DefaultFullHttpRequest(
+                        HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/custom-return-code");
+        req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);
+
+        channel.writeAndFlush(req);
+        latch.await();
+
+        Assert.assertEquals(httpStatus.code(), 599);
+        channel.close();
+
+        // Unload the model
+        channel = connect(true);
+        httpStatus = null;
+        latch = new CountDownLatch(1);
+        Assert.assertNotNull(channel);
+        req =
+                new DefaultFullHttpRequest(
+                        HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/custom-return-code");
+        channel.writeAndFlush(req);
+        latch.await();
+        Assert.assertEquals(httpStatus, HttpResponseStatus.OK);
+    }
+
     private void testErrorBatch() throws InterruptedException {
         Channel channel = connect(true);
         Assert.assertNotNull(channel);
diff --git a/mms/service.py b/mms/service.py
index 8250e5556..82590c023 100644
--- a/mms/service.py
+++ b/mms/service.py
@@ -106,6 +106,9 @@ def predict(self, batch):
         # noinspection PyBroadException
         try:
             ret = self._entry_point(input_batch, self.context)
+        except PredictionException as e:
+            logger.error("Prediction error", exc_info=True)
+            return create_predict_response(None, req_id_map, e.message, e.error_code)
         except MemoryError:
             logger.error("System out of memory", exc_info=True)
             return create_predict_response(None, req_id_map, "Out of resources", 507)
@@ -128,6 +131,16 @@ def predict(self, batch):
         return create_predict_response(ret, req_id_map, "Prediction success", 200, context=self.context)
 
 
+class PredictionException(Exception):
+    def __init__(self, message, error_code=500):
+        self.message = message
+        self.error_code = error_code
+        super(PredictionException, self).__init__(message)
+
+    def __str__(self):
+        return "{message} : {error_code}".format(message=self.message, error_code=self.error_code)
+
+
 def emit_metrics(metrics):
     """
     Emit the metrics in the provided Dictionary
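
For illustration, a minimal client-side sketch of how the custom error code introduced above surfaces to a caller. It assumes the `custom-return-code` test model from this change is registered and that MMS is serving the inference API on its default port 8080; the `requests` package and the endpoint URL are illustrative assumptions rather than part of the change.

```python
# Hypothetical client call against a locally running MMS instance (assumed port 8080)
# with the custom-return-code model registered. Its handler raises
# PredictionException("Some Prediction Error", 599), so the HTTP status of the
# prediction response is expected to be the custom code 599.
import requests

resp = requests.post(
    "http://127.0.0.1:8080/predictions/custom-return-code",
    data={"data": "invalid_output"},
)

print(resp.status_code, resp.text)
assert resp.status_code == 599
```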