Skip to content

Commit

Permalink
Add a timeout field to POST notebook
Browse files Browse the repository at this point in the history
This sets a timeout on the nbexec call to the JupyterLab server
extension for notebook execution. On timeout, it raises a
NbexecTaskTimeoutError exception in the arq worker, which should result
in an errored state for the notebook execution job. We might need to
fine-tune the exception handling so that the result output in the API is
useful.
  • Loading branch information
jonathansick committed Sep 9, 2024
1 parent 556251a commit 9f5bfef
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
9 changes: 9 additions & 0 deletions src/noteburst/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
__all__ = [
"TaskError",
"NbexecTaskError",
"NbexecTaskTimeoutError",
"NoteburstClientRequestError",
"NoteburstError",
]
Expand Down Expand Up @@ -38,6 +39,14 @@ class NbexecTaskError(TaskError):
task_name = "nbexec"


class NbexecTaskTimeoutError(NbexecTaskError):
    """Raised when execution of a notebook task times out."""

    @classmethod
    def from_exception(cls, exc: Exception) -> Self:
        """Create the error from the exception that signalled the timeout.

        The message is the task name followed by ``timeout error``, a blank
        line, and the string form of ``exc``.
        """
        message = f"{cls.task_name} timeout error\n\n{exc!s}"
        return cls(message)


class NoteburstClientRequestError(ClientRequestError):
"""Error related to the API client."""

Expand Down
1 change: 1 addition & 0 deletions src/noteburst/handlers/v1/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ async def post_nbexec(
ipynb=request_data.get_ipynb_as_str(),
kernel_name=request_data.kernel_name,
enable_retry=request_data.enable_retry,
timeout=request_data.timeout,
)
logger.info("Finished enqueing an nbexec task", job_id=job_metadata.id)
response_data = await NotebookResponse.from_job_metadata(
Expand Down
15 changes: 15 additions & 0 deletions src/noteburst/handlers/v1/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi import Request
from pydantic import AnyHttpUrl, BaseModel, Field
from safir.arq import JobMetadata, JobResult
from safir.pydantic import HumanTimedelta

from noteburst.jupyterclient.jupyterlab import (
NotebookExecutionErrorModel,
Expand Down Expand Up @@ -187,6 +188,20 @@ class PostNotebookRequest(BaseModel):
),
] = True

# Maximum time allowed for the notebook execution. The default is written as
# a human-readable duration string and run through the field's own validation
# (validate_default=True), the same way a value from a request body would be.
# `default_factory` is not used because it requires a zero-argument callable,
# not an already-constructed value.
timeout: Annotated[
    HumanTimedelta,
    Field(
        default="5m",
        validate_default=True,
        title="Timeout for notebook execution.",
        description=(
            "The timeout is a human-readable duration string. For "
            "example, '5m' is 5 minutes, '1h' is 1 hour, '1d' is 1 day. "
            "If the notebook execution does not complete within this "
            "time, the job is marked as failed."
        ),
    ),
]

def get_ipynb_as_str(self) -> str:
"""Get the ipynb as a JSON-encoded string."""
if isinstance(self.ipynb, str):
Expand Down
14 changes: 11 additions & 3 deletions src/noteburst/worker/functions/nbexec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

from __future__ import annotations

import asyncio
import json
import sys
from datetime import timedelta
from typing import Any, cast

from arq import Retry
from safir.slack.blockkit import SlackCodeBlock, SlackTextField

from noteburst.exceptions import NbexecTaskError
from noteburst.exceptions import NbexecTaskError, NbexecTaskTimeoutError
from noteburst.jupyterclient.jupyterlab import JupyterClient, JupyterError


Expand All @@ -21,6 +23,7 @@ async def nbexec(
ipynb: str,
kernel_name: str = "LSST",
enable_retry: bool = True,
timeout: timedelta | None = None, # noqa: ASYNC109
) -> str:
"""Execute a notebook, as an asynchronous arq worker task.
Expand Down Expand Up @@ -54,10 +57,15 @@ async def nbexec(
parsed_notebook = json.loads(ipynb)
logger.debug("Got ipynb", ipynb=parsed_notebook)
try:
execution_result = await jupyter_client.execute_notebook(
parsed_notebook, kernel_name=kernel_name
execution_result = await asyncio.wait_for(
jupyter_client.execute_notebook(
parsed_notebook, kernel_name=kernel_name
),
timeout=timeout.total_seconds() if timeout else None,
)
logger.info("nbexec finished", error=execution_result.error)
except TimeoutError as e:
raise NbexecTaskTimeoutError.from_exception(e) from e
except JupyterError as e:
logger.exception("nbexec error", jupyter_status=e.status)
if "slack" in ctx and "slack_message_factory" in ctx:
Expand Down

0 comments on commit 9f5bfef

Please sign in to comment.