diff --git a/.github/workflows/airflow-operator.yml b/.github/workflows/airflow-operator.yml index 4868ed39b58..cc4704cfa56 100644 --- a/.github/workflows/airflow-operator.yml +++ b/.github/workflows/airflow-operator.yml @@ -16,7 +16,6 @@ on: - 'internal/jobservice/*' - 'pkg/api/*.proto' - 'pkg/api/jobservice/*.proto' - - 'scripts/build-airflow-operator.sh' - 'scripts/build-python-client.sh' - 'third_party/airflow/**' - './magefiles/tests.go' @@ -37,7 +36,6 @@ on: - 'internal/jobservice/*' - 'pkg/api/*.proto' - 'pkg/api/jobservice/*.proto' - - 'scripts/build-airflow-operator.sh' - 'scripts/build-python-client.sh' - 'third_party/airflow/**' diff --git a/.goreleaser.yml b/.goreleaser.yml index 229b42622e2..251fe7001da 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -195,7 +195,7 @@ dockers: - --label=org.opencontainers.image.version={{ .Version }} - --label=org.opencontainers.image.created={{ time "2006-01-02T15:04:05Z07:00" }} - --label=org.opencontainers.image.revision={{ .FullCommit }} - - --label=org.opencontainers.image.base.name=alpine:3.18.3 + - --label=org.opencontainers.image.base.name=alpine:3.20.1 - --label=org.opencontainers.image.licenses=Apache-2.0 - --label=org.opencontainers.image.vendor=G-Research ids: diff --git a/.mergify.yml b/.mergify.yml index 24f148f21c7..debf68207d5 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -9,6 +9,6 @@ pull_request_rules: - "#approved-reviews-by>=2" - and: - "#approved-reviews-by>=1" - - "author~=^(d80tb7|dave[-]gantenbein|dejanzele|JamesMurkin|msumner91|masipauskas|mijovicmia|MustafaI|zuqq|richscott|robertdavidsmith|samclark)" + - "author~=^(d80tb7|dave[-]gantenbein|dejanzele|JamesMurkin|msumner91|masipauskas|mijovicmia|MustafaI|zuqq|richscott|robertdavidsmith|samclark|suprjinx)" title: Two are checks required. 
diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 3099c5897f6..abfaed050a5 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -16,6 +16,7 @@ | Rich Scott | [richscott](https://github.com/richscott) | G-Research | | Robert Smith | [robertdavidsmith](https://github.com/robertdavidsmith) | G-Research | | Sam Clark | [samclark](https://github.com/samclark) | G-Research | +| Geoffrey Wilson | [suprjinx](https://github.com/suprjinx) | G-Research | ## Past @@ -26,6 +27,5 @@ | Carlo Camurri | [carlocamurri](https://github.com/carlocamurri) | G-Research | | Clifton Houck | [ClifHouck](https://github.com/ClifHouck) | G-Research | | Daniel Rastelli | [theAntiYeti](https://github.com/theAntiYeti) | G-Research | -| Geoffrey Wilson | [suprjinx](https://github.com/suprjinx) | G-Research | | Jamie Poole | [jimbobby5](https://github.com/jimbobby5) | G-Research | | Kevin Hannon | [kannon92](https://github.com/kannon92) | G-Research | diff --git a/build/airflow-operator/Dockerfile b/build/airflow-operator/Dockerfile index a3d774b30d6..87a2e81a5cb 100644 --- a/build/airflow-operator/Dockerfile +++ b/build/airflow-operator/Dockerfile @@ -1,5 +1,5 @@ ARG PLATFORM=x86_64 -ARG BASE_IMAGE=python:3.8.18-bookworm +ARG BASE_IMAGE=python:3.10.14-bookworm FROM --platform=$PLATFORM ${BASE_IMAGE} RUN mkdir /proto diff --git a/build/armada-load-tester/Dockerfile b/build/armada-load-tester/Dockerfile index 09b8b4aeac9..9ecee0f7061 100644 --- a/build/armada-load-tester/Dockerfile +++ b/build/armada-load-tester/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/armada/Dockerfile b/build/armada/Dockerfile index 6614890e768..7d15d88abb1 100644 --- a/build/armada/Dockerfile +++ b/build/armada/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/armadactl/Dockerfile b/build/armadactl/Dockerfile index 1fb97defb9e..1c0e5518eab 100644 --- a/build/armadactl/Dockerfile +++ b/build/armadactl/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/binoculars/Dockerfile b/build/binoculars/Dockerfile index 640fd53b986..a2d05ef191e 100644 --- a/build/binoculars/Dockerfile +++ b/build/binoculars/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/eventingester/Dockerfile b/build/eventingester/Dockerfile index ea77f3c9ca0..e1c29decfb0 100644 --- a/build/eventingester/Dockerfile +++ b/build/eventingester/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/executor/Dockerfile b/build/executor/Dockerfile index 9a139fffbed..99dfe76151a 100644 --- a/build/executor/Dockerfile +++ b/build/executor/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/fakeexecutor/Dockerfile b/build/fakeexecutor/Dockerfile index 8f822b59581..7e92d6b1ec5 100644 --- a/build/fakeexecutor/Dockerfile +++ b/build/fakeexecutor/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/jobservice/Dockerfile b/build/jobservice/Dockerfile index 1f5bc9a9af2..b9340243bfd 
100644 --- a/build/jobservice/Dockerfile +++ b/build/jobservice/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/lookoutingesterv2/Dockerfile b/build/lookoutingesterv2/Dockerfile index f8128d0cc9a..c2ef341b15d 100644 --- a/build/lookoutingesterv2/Dockerfile +++ b/build/lookoutingesterv2/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/lookoutv2/Dockerfile b/build/lookoutv2/Dockerfile index 0d463389398..1a672055b0b 100644 --- a/build/lookoutv2/Dockerfile +++ b/build/lookoutv2/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/scheduler/Dockerfile b/build/scheduler/Dockerfile index b9cab04aebc..28a11f01b8e 100644 --- a/build/scheduler/Dockerfile +++ b/build/scheduler/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/scheduleringester/Dockerfile b/build/scheduleringester/Dockerfile index 810c76e9a01..817dd8d297c 100644 --- a/build/scheduleringester/Dockerfile +++ b/build/scheduleringester/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build/testsuite/Dockerfile b/build/testsuite/Dockerfile index b3a69121166..319eee90738 100644 --- a/build/testsuite/Dockerfile +++ b/build/testsuite/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.18.3 +FROM alpine:3.20.1 RUN addgroup -S -g 2000 armada && adduser -S -u 1000 armada -G armada diff --git a/build_goreleaser/armadactl/Dockerfile b/build_goreleaser/armadactl/Dockerfile index b286a5ae77d..80d6b6a66df 100644 --- a/build_goreleaser/armadactl/Dockerfile +++ b/build_goreleaser/armadactl/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=armadactl diff --git a/build_goreleaser/binoculars/Dockerfile b/build_goreleaser/binoculars/Dockerfile index a64955d0003..67e3c944ff1 100644 --- a/build_goreleaser/binoculars/Dockerfile +++ b/build_goreleaser/binoculars/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=binoculars LABEL org.opencontainers.image.description="binoculars" diff --git a/build_goreleaser/bundles/armada/Dockerfile b/build_goreleaser/bundles/armada/Dockerfile index 133a11e853f..3403e3b620f 100644 --- a/build_goreleaser/bundles/armada/Dockerfile +++ b/build_goreleaser/bundles/armada/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=armada LABEL org.opencontainers.image.description="Armada Bundle" diff --git a/build_goreleaser/bundles/full/Dockerfile b/build_goreleaser/bundles/full/Dockerfile index c4c09916783..a1e4ec658ae 100644 --- a/build_goreleaser/bundles/full/Dockerfile +++ b/build_goreleaser/bundles/full/Dockerfile @@ -1,6 +1,6 @@ ARG NODE_BUILD_IMAGE=node:16.14-buster ARG OPENAPI_BUILD_IMAGE=openapitools/openapi-generator-cli:v5.4.0 -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${OPENAPI_BUILD_IMAGE} AS OPENAPI LABEL org.opencontainers.image.title=armada-full-bundle diff --git a/build_goreleaser/bundles/lookout/Dockerfile 
b/build_goreleaser/bundles/lookout/Dockerfile index 3a180b7ce7b..ac1e173ca95 100644 --- a/build_goreleaser/bundles/lookout/Dockerfile +++ b/build_goreleaser/bundles/lookout/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=armada-lookout-bundle diff --git a/build_goreleaser/eventingester/Dockerfile b/build_goreleaser/eventingester/Dockerfile index dd15e0a8a3a..70431bb3967 100644 --- a/build_goreleaser/eventingester/Dockerfile +++ b/build_goreleaser/eventingester/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.5 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=eventingester diff --git a/build_goreleaser/executor/Dockerfile b/build_goreleaser/executor/Dockerfile index 36d7ceeb679..0b627d7f2ce 100644 --- a/build_goreleaser/executor/Dockerfile +++ b/build_goreleaser/executor/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=executor diff --git a/build_goreleaser/fakeexecutor/Dockerfile b/build_goreleaser/fakeexecutor/Dockerfile index d7fa88edb17..444cdc5afb0 100644 --- a/build_goreleaser/fakeexecutor/Dockerfile +++ b/build_goreleaser/fakeexecutor/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=fakeexecutor LABEL org.opencontainers.image.description="Fake Executor" diff --git a/build_goreleaser/jobservice/Dockerfile b/build_goreleaser/jobservice/Dockerfile index 9da0241774b..cb1540bbb7a 100644 --- a/build_goreleaser/jobservice/Dockerfile +++ b/build_goreleaser/jobservice/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=jobservice diff --git a/build_goreleaser/loadtester/Dockerfile b/build_goreleaser/loadtester/Dockerfile index e716e2fa7e1..5ecb9156aa7 100644 --- a/build_goreleaser/loadtester/Dockerfile +++ b/build_goreleaser/loadtester/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=loadtester diff --git a/build_goreleaser/lookoutingesterv2/Dockerfile b/build_goreleaser/lookoutingesterv2/Dockerfile index be74008b091..25595221e68 100644 --- a/build_goreleaser/lookoutingesterv2/Dockerfile +++ b/build_goreleaser/lookoutingesterv2/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=lookoutingesterv2 @@ -13,4 +13,4 @@ COPY config/lookoutingesterv2/config.yaml /app/config/lookoutingesterv2/config.y WORKDIR /app -ENTRYPOINT ["./lookoutingesterv2"] \ No newline at end of file +ENTRYPOINT ["./lookoutingesterv2"] diff --git a/build_goreleaser/lookoutv2/Dockerfile b/build_goreleaser/lookoutv2/Dockerfile index b3c07af4097..7d0e1ebca53 100644 --- a/build_goreleaser/lookoutv2/Dockerfile +++ b/build_goreleaser/lookoutv2/Dockerfile @@ -1,6 +1,6 @@ ARG NODE_BUILD_IMAGE=node:16.14-buster ARG OPENAPI_BUILD_IMAGE=openapitools/openapi-generator-cli:v5.4.0 -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${OPENAPI_BUILD_IMAGE} AS OPENAPI diff --git a/build_goreleaser/scheduler/Dockerfile b/build_goreleaser/scheduler/Dockerfile index 6922cd3be2e..24de8d5b69e 100644 --- a/build_goreleaser/scheduler/Dockerfile +++ b/build_goreleaser/scheduler/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG 
BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=scheduler diff --git a/build_goreleaser/scheduleringester/Dockerfile b/build_goreleaser/scheduleringester/Dockerfile index 40a58a9e5b7..1d5096fc9ef 100644 --- a/build_goreleaser/scheduleringester/Dockerfile +++ b/build_goreleaser/scheduleringester/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=scheduleringester diff --git a/build_goreleaser/server/Dockerfile b/build_goreleaser/server/Dockerfile index 9568aa50aad..614ef8591e9 100644 --- a/build_goreleaser/server/Dockerfile +++ b/build_goreleaser/server/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=armada-server diff --git a/build_goreleaser/testsuite/Dockerfile b/build_goreleaser/testsuite/Dockerfile index 514c37566a8..c8dfccf4a95 100644 --- a/build_goreleaser/testsuite/Dockerfile +++ b/build_goreleaser/testsuite/Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=alpine:3.18.3 +ARG BASE_IMAGE=alpine:3.20.1 FROM ${BASE_IMAGE} LABEL org.opencontainers.image.title=testsuite LABEL org.opencontainers.image.description="Test Suite" diff --git a/cmd/armada-load-tester/cmd/loadtest.go b/cmd/armada-load-tester/cmd/loadtest.go index 126379afd4d..ad5a47eea85 100644 --- a/cmd/armada-load-tester/cmd/loadtest.go +++ b/cmd/armada-load-tester/cmd/loadtest.go @@ -52,7 +52,7 @@ var loadtestCmd = &cobra.Command{ containers: - name: sleep imagePullPolicy: IfNotPresent - image: alpine:3.18.3 + image: alpine:3.20.1 command: - sh args: diff --git a/config/armada/config.yaml b/config/armada/config.yaml index cb7e1b4df2a..950e7589308 100644 --- a/config/armada/config.yaml +++ b/config/armada/config.yaml @@ -75,6 +75,7 @@ pulsar: compressionLevel: faster eventsPrinter: false eventsPrinterSubscription: "EventsPrinter" + maxAllowedEventsPerMessage: 1000 maxAllowedMessageSize: 4194304 # 4MB receiverQueueSize: 100 postgres: diff --git a/config/eventingester/config.yaml b/config/eventingester/config.yaml index 292712bb25c..b75ce43c98b 100644 --- a/config/eventingester/config.yaml +++ b/config/eventingester/config.yaml @@ -12,9 +12,9 @@ pulsar: receiverQueueSize: 100 subscriptionName: "events-ingester" minMessageCompressionSize: 1024 -batchSize: 1048576 #1MB +maxOutputMessageSizeBytes: 1048576 #1MB +batchSize: 10000 batchDuration: 100ms -batchMessages: 10000 eventRetentionPolicy: retentionDuration: 336h metricsPort: 9001 diff --git a/config/scheduler/config.yaml b/config/scheduler/config.yaml index 07c432e1218..cd5d312dbdb 100644 --- a/config/scheduler/config.yaml +++ b/config/scheduler/config.yaml @@ -37,6 +37,7 @@ pulsar: maxConnectionsPerBroker: 1 compressionType: zlib compressionLevel: faster + maxAllowedEventsPerMessage: 1000 maxAllowedMessageSize: 4194304 #4Mi armadaApi: armadaUrl: "server:50051" @@ -86,6 +87,7 @@ scheduling: disableScheduling: false enableAssertions: false protectedFractionOfFairShare: 1.0 + useAdjustedFairShareProtection: true nodeIdLabel: "kubernetes.io/hostname" priorityClasses: armada-default: diff --git a/developer/config/job.yaml b/developer/config/job.yaml index aab2b8d257b..38906368930 100644 --- a/developer/config/job.yaml +++ b/developer/config/job.yaml @@ -11,7 +11,7 @@ jobs: containers: - name: sleep imagePullPolicy: IfNotPresent - image: alpine:3.10 + image: alpine:latest args: - "exit" - "1" @@ -26,4 +26,4 @@ jobs: timeout: "100s" expectedEvents: - submitted: - - failed: \ 
No newline at end of file + - failed: diff --git a/developer/env/docker/server.env b/developer/env/docker/server.env index 6b52f9b6342..a5b4496abe4 100644 --- a/developer/env/docker/server.env +++ b/developer/env/docker/server.env @@ -1,3 +1,3 @@ ARMADA_QUEUECACHEREFRESHPERIOD="1s" ARMADA_CORSALLOWEDORIGINS="http://localhost:3000,http://localhost:10000,http://example.com:10000" - +ARMADA_QUERYAPI_POSTGRES_CONNECTION_HOST=postgres diff --git a/docker-compose.yaml b/docker-compose.yaml index 4e96216ef27..68af1ce7aa2 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -61,6 +61,7 @@ services: depends_on: - lookoutv2-migration - eventingester + - lookoutingesterv2 working_dir: /app env_file: - developer/env/docker/server.env diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 048667a2562..665d6c0e82c 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -12,66 +12,61 @@ This class provides integration with Airflow and Armada ## armada.operators.armada module -### _class_ armada.operators.armada.ArmadaOperator(name, armada_channel_args, job_service_channel_args, armada_queue, job_request_items, lookout_url_template=None, poll_interval=30, \*\*kwargs) -Bases: `BaseOperator` +### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, \*\*kwargs) +Bases: `BaseOperator`, `LoggingMixin` -Implementation of an ArmadaOperator for airflow. +An Airflow operator that manages Job submission to Armada. -Airflow operators inherit from BaseOperator. +This operator submits a job to an Armada cluster, polls for its completion, +and handles job cancellation if the Airflow task is killed. * **Parameters** - * **name** (*str*) – The name of the airflow task + * **name** (*str*) – - * **armada_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. + * **channel_args** (*GrpcChannelArgs*) – - * **job_service_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. + * **armada_queue** (*str*) – - * **armada_queue** (*str*) – The queue name for Armada. + * **job_request** (*JobSubmitRequestItem*) – - * **job_request_items** (*List**[**JobSubmitRequestItem**]*) – A PodSpec that is used by Armada for submitting a job + * **job_set_prefix** (*str** | **None*) – - * **lookout_url_template** (*str** | **None*) – A URL template to be used to provide users - a valid link to the related lookout job in this operator’s log. - The format should be: - “[https://lookout.armada.domain/jobs](https://lookout.armada.domain/jobs)?job_id=” where will - be replaced with the actual job ID. + * **lookout_url_template** (*str** | **None*) – - * **poll_interval** (*int*) – How often to poll jobservice to get status. + * **poll_interval** (*int*) – + * **container_logs** (*str** | **None*) – -* **Returns** - an armada operator instance + * **k8s_token_retriever** (*TokenRetriever** | **None*) – + * **deferrable** (*bool*) – -#### execute(context) -Executes the Armada Operator. -Runs an Armada job and calls the job_service_client for polling. + * **job_acknowledgement_timeout** (*int*) – -* **Parameters** - **context** – The airflow context. 
+#### _property_ client(_: ArmadaClien_ ) +#### execute(context) +Submits the job to Armada and polls for completion. -* **Returns** +* **Parameters** - None + **context** (*Context*) – The execution context provided by Airflow. @@ -81,20 +76,11 @@ Runs an Armada job and calls the job_service_client for polling. -#### render_template_fields(context, jinja_env=None) -Template all attributes listed in *self.template_fields*. - -This mutates the attributes in-place and is irreversible. - - -* **Parameters** - - - * **context** (*Context*) – Context dict with values to apply on content. - - - * **jinja_env** (*Environment** | **None*) – Jinja’s environment to use for rendering. +#### on_kill() +Override this method to clean up subprocesses when a task instance gets killed. +Any use of the threading, subprocess or multiprocessing module within an +operator needs to be cleaned up, or it will leave ghost processes behind. * **Return type** @@ -103,133 +89,36 @@ This mutates the attributes in-place and is irreversible. -#### template_fields(_: Sequence[str_ _ = ('job_request_items',_ ) -## armada.operators.armada_deferrable module - - -### _class_ armada.operators.armada_deferrable.ArmadaDeferrableOperator(name, armada_channel_args, job_service_channel_args, armada_queue, job_request_items, lookout_url_template=None, poll_interval=30, \*\*kwargs) -Bases: `BaseOperator` - -Implementation of a deferrable armada operator for airflow. - -Distinguished from ArmadaOperator by its ability to defer itself after -submitting its job_request_items. - -See -[https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html](https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html) -for more information about deferrable airflow operators. - -Airflow operators inherit from BaseOperator. - - -* **Parameters** - - - * **name** (*str*) – The name of the airflow task. - - - * **armada_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. - - - * **job_service_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. - - - * **armada_queue** (*str*) – The queue name for Armada. - - - * **job_request_items** (*List**[**JobSubmitRequestItem**]*) – A PodSpec that is used by Armada for submitting a job. - - - * **lookout_url_template** (*str** | **None*) – A URL template to be used to provide users - a valid link to the related lookout job in this operator’s log. - The format should be: - “[https://lookout.armada.domain/jobs](https://lookout.armada.domain/jobs)?job_id=” where will - be replaced with the actual job ID. - - - * **poll_interval** (*int*) – How often to poll jobservice to get status. - - - -* **Returns** - - A deferrable armada operator instance. - - - -#### execute(context) -Executes the Armada Operator. Only meant to be called by airflow. - -Submits an Armada job and defers itself to ArmadaJobCompleteTrigger to wait -until the job completes. - +#### pod_manager(k8s_context) * **Parameters** - **context** – The airflow context. - - - -* **Returns** - - None + **k8s_context** (*str*) – * **Return type** - None + *PodLogManager* #### render_template_fields(context, jinja_env=None) -Template all attributes listed in *self.template_fields*. - +Template all attributes listed in self.template_fields. This mutates the attributes in-place and is irreversible. 
+Args: -* **Parameters** - - - * **context** (*Context*) – Context dict with values to apply on content - - - * **jinja_env** (*Environment** | **None*) – Jinja’s environment to use for rendering. - - - -#### resume_job_complete(context, event, job_id) -Resumes this operator after deferring itself to ArmadaJobCompleteTrigger. -Only meant to be called from within Airflow. - -Reports the result of the job and returns. + context (Context): The execution context provided by Airflow. * **Parameters** * **context** (*Context*) – Airflow Context dict with values to apply on content * **jinja_env** (*Environment** | **None*) – Jinja’s environment to use for rendering. @@ -239,492 +128,52 @@ Reports the result of the job and returns. -#### serialize() -Get a serialized version of this object. - - -* **Returns** - - A dict of keyword arguments used when instantiating - - - -* **Return type** - - dict - - -this object. - - -#### template_fields(_: Sequence[str_ _ = ('job_request_items',_ ) - -### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name, poll_interval=30) -Bases: `BaseTrigger` - -An airflow trigger that monitors the job state of an armada job. - -Triggers when the job is complete. +#### template_fields(_: Sequence[str_ _ = ('job_request', 'job_set_prefix'_ ) +Initializes a new ArmadaOperator. * **Parameters** - * **job_id** (*str*) – The job ID to monitor. - - - * **job_service_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when - creating a grpc channel to connect to the job service instance. + * **name** (*str*) – The name of the job to be submitted. - * **armada_queue** (*str*) – The name of the armada queue. + * **channel_args** (*GrpcChannelArgs*) – The gRPC channel arguments for connecting to the Armada server. - * **job_set_id** (*str*) – The ID of the job set. + * **armada_queue** (*str*) – The name of the Armada queue to which the job will be submitted. - * **airflow_task_name** (*str*) – Name of the airflow task to which this trigger - belongs. + * **job_request** (*JobSubmitRequestItem*) – The job to be submitted to Armada. - * **poll_interval** (*int*) – How often to poll jobservice to get status. + * **job_set_prefix** (*Optional**[**str**]*) – A string to prepend to the jobSet name + * **lookout_url_template** – Template for creating lookout links. If not specified -* **Returns** - An armada job complete trigger instance. - - - -#### _async_ run() -Runs the trigger. Meant to be called by an airflow triggerer process. - - -#### serialize() -Return the information needed to reconstruct this Trigger. - - -* **Returns** - - Tuple of (class path, keyword arguments needed to re-instantiate). - - - -* **Return type** - - tuple +then no tracking information will be logged. +:type lookout_url_template: Optional[str] +:param poll_interval: The interval in seconds between polling for job status updates. +:type poll_interval: int +:param container_logs: Name of container whose logs will be published to stdout. +:type container_logs: Optional[str] +:param k8s_token_retriever: A serialisable Kubernetes token retriever object. We use +this to read logs from Kubernetes pods.
+:type k8s_token_retriever: Optional[TokenRetriever] +:param deferrable: Whether the operator should run in a deferrable mode, allowing +for asynchronous execution. +:type deferrable: bool +:param job_acknowledgement_timeout: The timeout in seconds to wait for a job to be +acknowledged by Armada. +:type job_acknowledgement_timeout: int +:param kwargs: Additional keyword arguments to pass to the BaseOperator. +## armada.operators.armada_deferrable module ## armada.operators.jobservice module - -### _class_ armada.operators.jobservice.JobServiceClient(channel) -Bases: `object` - -The JobService Client - -Implementation of gRPC stubs from JobService - - -* **Parameters** - - **channel** – gRPC channel used for authentication. See - [https://grpc.github.io/grpc/python/grpc.html](https://grpc.github.io/grpc/python/grpc.html) - for more information. - - - -* **Returns** - - a job service client instance - - - -#### get_job_status(queue, job_set_id, job_id) -Get job status of a given job in a queue and job_set_id. - -Uses the GetJobStatus rpc to get a status of your job - - -* **Parameters** - - - * **queue** (*str*) – The name of the queue - - - * **job_set_id** (*str*) – The name of the job set (a grouping of jobs) - - - * **job_id** (*str*) – The id of the job - - - -* **Returns** - - A Job Service Request (State, Error) - - - -* **Return type** - - *JobServiceResponse* - - - -#### health() -Health Check for GRPC Request - - -* **Return type** - - *HealthCheckResponse* - - - -### armada.operators.jobservice.get_retryable_job_service_client(target, credentials=None, compression=None) -Get a JobServiceClient that has retry configured - - -* **Parameters** - - - * **target** (*str*) – grpc channel target - - - * **credentials** (*ChannelCredentials** | **None*) – grpc channel credentials (if needed) - - - * **compresion** – grpc channel compression - - - * **compression** (*Compression** | **None*) – - - - -* **Returns** - - A job service client instance - - - -* **Return type** - - *JobServiceClient* - - ## armada.operators.jobservice_asyncio module - -### _class_ armada.operators.jobservice_asyncio.JobServiceAsyncIOClient(channel) -Bases: `object` - -The JobService AsyncIO Client - -AsyncIO implementation of gRPC stubs from JobService - - -* **Parameters** - - **channel** (*Channel*) – AsyncIO gRPC channel used for authentication. See - [https://grpc.github.io/grpc/python/grpc_asyncio.html](https://grpc.github.io/grpc/python/grpc_asyncio.html) - for more information. - - - -* **Returns** - - A job service client instance - - - -#### _async_ get_job_status(queue, job_set_id, job_id) -Get job status of a given job in a queue and job_set_id. 
- -Uses the GetJobStatus rpc to get a status of your job - - -* **Parameters** - - - * **queue** (*str*) – The name of the queue - - - * **job_set_id** (*str*) – The name of the job set (a grouping of jobs) - - - * **job_id** (*str*) – The id of the job - - - -* **Returns** - - A Job Service Request (State, Error) - - - -* **Return type** - - *JobServiceResponse* - - - -#### _async_ health() -Health Check for GRPC Request - - -* **Return type** - - *HealthCheckResponse* - - - -### armada.operators.jobservice_asyncio.get_retryable_job_service_asyncio_client(target, credentials, compression) -Get a JobServiceAsyncIOClient that has retry configured - - -* **Parameters** - - - * **target** (*str*) – grpc channel target - - - * **credentials** (*ChannelCredentials** | **None*) – grpc channel credentials (if needed) - - - * **compresion** – grpc channel compression - - - * **compression** (*Compression** | **None*) – - - - -* **Returns** - - A job service asyncio client instance - - - -* **Return type** - - *JobServiceAsyncIOClient* - - ## armada.operators.utils module - - -### _class_ armada.operators.utils.JobState(value) -Bases: `Enum` - -An enumeration. - - -#### CANCELLED(_ = _ ) - -#### CONNECTION_ERR(_ = _ ) - -#### DUPLICATE_FOUND(_ = _ ) - -#### FAILED(_ = _ ) - -#### JOB_ID_NOT_FOUND(_ = _ ) - -#### RUNNING(_ = _ ) - -#### SUBMITTED(_ = _ ) - -#### SUCCEEDED(_ = _ ) - -### armada.operators.utils.airflow_error(job_state, name, job_id) -Throw an error on a terminal event if job errored out - - -* **Parameters** - - - * **job_state** (*JobState*) – A JobState enum class - - - * **name** (*str*) – The name of your armada job - - - * **job_id** (*str*) – The job id that armada assigns to it - - - -* **Returns** - - No Return or an AirflowFailException. - - -AirflowFailException tells Airflow Schedule to not reschedule the task - - -### armada.operators.utils.annotate_job_request_items(context, job_request_items) -Annotates the inbound job request items with Airflow context elements - - -* **Parameters** - - - * **context** – The airflow context. - - - * **job_request_items** (*List**[**JobSubmitRequestItem**]*) – The job request items to be sent to armada - - - -* **Returns** - - annotated job request items for armada - - - -* **Return type** - - *List*[*JobSubmitRequestItem*] - - - -### armada.operators.utils.default_job_status_callable(armada_queue, job_set_id, job_id, job_service_client) - -* **Parameters** - - - * **armada_queue** (*str*) – - - - * **job_set_id** (*str*) – - - - * **job_id** (*str*) – - - - * **job_service_client** (*JobServiceClient*) – - - - -* **Return type** - - *JobServiceResponse* - - - -### armada.operators.utils.get_annotation_key_prefix() -Provides the annotation key prefix, -which can be specified in env var ANNOTATION_KEY_PREFIX. -A default is provided if the env var is not defined - - -* **Returns** - - string annotation key prefix - - - -* **Return type** - - str - - - -### armada.operators.utils.job_state_from_pb(state) - -* **Return type** - - *JobState* - - - -### armada.operators.utils.search_for_job_complete(armada_queue, job_set_id, airflow_task_name, job_id, poll_interval=30, job_service_client=None, job_status_callable=, time_out_for_failure=7200) -Poll JobService cache until you get a terminated event. 
- -A terminated event is SUCCEEDED, FAILED or CANCELLED - - -* **Parameters** - - - * **armada_queue** (*str*) – The queue for armada - - - * **job_set_id** (*str*) – Your job_set_id - - - * **airflow_task_name** (*str*) – The name of your armada job - - - * **poll_interval** (*int*) – Polling interval for jobservice to get status. - - - * **job_id** (*str*) – The name of the job id that armada assigns to it - - - * **job_service_client** (*JobServiceClient** | **None*) – A JobServiceClient that is used for polling. - It is optional only for testing - - - * **job_status_callable** – A callable object for test injection. - - - * **time_out_for_failure** (*int*) – The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - - - -* **Returns** - - A tuple of JobStateEnum, message - - - -* **Return type** - - *Tuple*[*JobState*, str] - - - -### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, poll_interval, time_out_for_failure=7200) -Poll JobService cache asyncronously until you get a terminated event. - -A terminated event is SUCCEEDED, FAILED or CANCELLED - - -* **Parameters** - - - * **armada_queue** (*str*) – The queue for armada - - - * **job_set_id** (*str*) – Your job_set_id - - - * **airflow_task_name** (*str*) – The name of your armada job - - - * **job_id** (*str*) – The name of the job id that armada assigns to it - - - * **job_service_client** (*JobServiceAsyncIOClient*) – A JobServiceClient that is used for polling. - It is optional only for testing - - - * **poll_interval** (*int*) – How often to poll jobservice to get status. - - - * **time_out_for_failure** (*int*) – The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - - - -* **Returns** - - A tuple of JobStateEnum, message - - - -* **Return type** - - *Tuple*[*JobState*, str] diff --git a/e2e/armadactl_test/armadactl_test.go b/e2e/armadactl_test/armadactl_test.go index f5ee2fd956d..419b4eaa114 100644 --- a/e2e/armadactl_test/armadactl_test.go +++ b/e2e/armadactl_test/armadactl_test.go @@ -175,7 +175,7 @@ jobs: containers: - name: ls imagePullPolicy: IfNotPresent - image: alpine:3.18.3 + image: alpine:3.20.1 command: - sh - -c diff --git a/internal/armada/configuration/types.go b/internal/armada/configuration/types.go index 575ea21f77a..b1a3c5c9d5f 100644 --- a/internal/armada/configuration/types.go +++ b/internal/armada/configuration/types.go @@ -66,6 +66,8 @@ type PulsarConfig struct { CompressionLevel pulsar.CompressionLevel // Settings for deduplication, which relies on a postgres server. 
DedupTable string + // Maximum allowed Events per message + MaxAllowedEventsPerMessage int `validate:"gte=0"` // Maximum allowed message size in bytes MaxAllowedMessageSize uint // Timeout when polling pulsar for messages diff --git a/internal/armada/server.go b/internal/armada/server.go index 86fbabfae21..f24d1be5750 100644 --- a/internal/armada/server.go +++ b/internal/armada/server.go @@ -128,7 +128,7 @@ func Serve(ctx *armadacontext.Context, config *configuration.ArmadaConfig, healt CompressionLevel: config.Pulsar.CompressionLevel, BatchingMaxSize: config.Pulsar.MaxAllowedMessageSize, Topic: config.Pulsar.JobsetEventsTopic, - }, config.Pulsar.MaxAllowedMessageSize) + }, config.Pulsar.MaxAllowedEventsPerMessage, config.Pulsar.MaxAllowedMessageSize) if err != nil { return errors.Wrapf(err, "error creating pulsar producer") } diff --git a/internal/common/eventutil/eventutil.go b/internal/common/eventutil/eventutil.go index c83786f868f..9ca907556b3 100644 --- a/internal/common/eventutil/eventutil.go +++ b/internal/common/eventutil/eventutil.go @@ -14,6 +14,7 @@ import ( "github.com/armadaproject/armada/internal/common/armadacontext" "github.com/armadaproject/armada/internal/common/armadaerrors" + "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/pkg/api" "github.com/armadaproject/armada/pkg/armadaevents" ) @@ -253,6 +254,29 @@ func groupsEqual(g1, g2 []string) bool { return true } +func LimitSequencesEventMessageCount(sequences []*armadaevents.EventSequence, maxEventsPerSequence int) []*armadaevents.EventSequence { + rv := make([]*armadaevents.EventSequence, 0, len(sequences)) + for _, sequence := range sequences { + if len(sequence.Events) > maxEventsPerSequence { + splitEventMessages := slices.PartitionToMaxLen(sequence.Events, maxEventsPerSequence) + + for _, eventMessages := range splitEventMessages { + rv = append(rv, &armadaevents.EventSequence{ + Queue: sequence.Queue, + JobSetName: sequence.JobSetName, + UserId: sequence.UserId, + Groups: sequence.Groups, + Events: eventMessages, + }) + } + + } else { + rv = append(rv, sequence) + } + } + return rv +} + // LimitSequencesByteSize calls LimitSequenceByteSize for each of the provided sequences // and returns all resulting sequences. 
func LimitSequencesByteSize(sequences []*armadaevents.EventSequence, sizeInBytes uint, strict bool) ([]*armadaevents.EventSequence, error) { diff --git a/internal/common/eventutil/eventutil_test.go b/internal/common/eventutil/eventutil_test.go index e1015b00359..3d68c4c8f86 100644 --- a/internal/common/eventutil/eventutil_test.go +++ b/internal/common/eventutil/eventutil_test.go @@ -286,6 +286,68 @@ func TestSequenceEventListSizeBytes(t *testing.T) { assert.True(t, sequenceSizeBytes < sequenceEventListOverheadSizeBytes) } +func TestLimitSequencesEventMessageCount(t *testing.T) { + input := []*armadaevents.EventSequence{ + { + Queue: "queue1", + UserId: "userId1", + JobSetName: "jobSetName1", + Groups: []string{"group1", "group2"}, + Events: []*armadaevents.EventSequence_Event{ + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "a"}}}, + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "b"}}}, + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "c"}}}, + }, + }, + { + Queue: "queue2", + UserId: "userId1", + JobSetName: "jobSetName1", + Groups: []string{"group1", "group2"}, + Events: []*armadaevents.EventSequence_Event{ + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "d"}}}, + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "e"}}}, + }, + }, + } + + expected := []*armadaevents.EventSequence{ + { + Queue: "queue1", + UserId: "userId1", + JobSetName: "jobSetName1", + Groups: []string{"group1", "group2"}, + Events: []*armadaevents.EventSequence_Event{ + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "a"}}}, + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "b"}}}, + }, + }, + { + Queue: "queue1", + UserId: "userId1", + JobSetName: "jobSetName1", + Groups: []string{"group1", "group2"}, + Events: []*armadaevents.EventSequence_Event{ + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "c"}}}, + }, + }, + { + Queue: "queue2", + UserId: "userId1", + JobSetName: "jobSetName1", + Groups: []string{"group1", "group2"}, + Events: []*armadaevents.EventSequence_Event{ + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "d"}}}, + {Event: &armadaevents.EventSequence_Event_SubmitJob{SubmitJob: &armadaevents.SubmitJob{JobIdStr: "e"}}}, + }, + }, + } + + result := LimitSequencesEventMessageCount(input, 2) + assert.Len(t, result, 3) + assert.Equal(t, expected, result) +} + func TestLimitSequenceByteSize(t *testing.T) { sequence := &armadaevents.EventSequence{ Queue: "queue1", diff --git a/internal/common/ingest/batch.go b/internal/common/ingest/batch.go index 52284e41b90..bbd1f8a574c 100644 --- a/internal/common/ingest/batch.go +++ b/internal/common/ingest/batch.go @@ -17,25 +17,32 @@ type Batcher[T any] struct { maxItems int maxTimeout time.Duration clock clock.Clock - callback func([]T) - buffer []T - mutex sync.Mutex + // This function is used to determine how many items are in a given input + // This allows customising how the batcher batches up your input + // Such as if you are batching objects A, but want to limit on the number of A.[]B objects seen + // In which case this function should return len(A.[]B) + itemCountFunc func(T) int + callback func([]T) + buffer []T + mutex 
sync.Mutex } -func NewBatcher[T any](input <-chan T, maxItems int, maxTimeout time.Duration, callback func([]T)) *Batcher[T] { +func NewBatcher[T any](input <-chan T, maxItems int, maxTimeout time.Duration, itemCountFunc func(T) int, callback func([]T)) *Batcher[T] { return &Batcher[T]{ - input: input, - maxItems: maxItems, - maxTimeout: maxTimeout, - callback: callback, - clock: clock.RealClock{}, - mutex: sync.Mutex{}, + input: input, + maxItems: maxItems, + maxTimeout: maxTimeout, + itemCountFunc: itemCountFunc, + callback: callback, + clock: clock.RealClock{}, + mutex: sync.Mutex{}, } } func (b *Batcher[T]) Run(ctx *armadacontext.Context) { for { b.buffer = []T{} + totalNumberOfItems := 0 expire := b.clock.After(b.maxTimeout) for appendToBatch := true; appendToBatch; { select { @@ -50,7 +57,8 @@ func (b *Batcher[T]) Run(ctx *armadacontext.Context) { } b.mutex.Lock() b.buffer = append(b.buffer, value) - if len(b.buffer) == b.maxItems { + totalNumberOfItems += b.itemCountFunc(value) + if totalNumberOfItems >= b.maxItems { b.callback(b.buffer) appendToBatch = false } diff --git a/internal/common/ingest/batch_test.go b/internal/common/ingest/batch_test.go index 3160303dd54..eb1a07fbcbb 100644 --- a/internal/common/ingest/batch_test.go +++ b/internal/common/ingest/batch_test.go @@ -17,6 +17,8 @@ const ( defaultMaxTimeOut = 5 * time.Second ) +var defaultItemCounterFunc = func(i int) int { return 1 } + type resultHolder struct { result [][]int mutex sync.Mutex } @@ -46,15 +48,15 @@ func TestBatch_MaxItems(t *testing.T) { testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() - batcher := NewBatcher[int](inputChan, defaultMaxItems, defaultMaxTimeOut, result.add) + batcher := NewBatcher[int](inputChan, defaultMaxItems, defaultMaxTimeOut, defaultItemCounterFunc, result.add) batcher.clock = testClock go func() { batcher.Run(ctx) }() - // Post 3 items on the input channel without advancing the clock - // And we should get a single update on the output channel + // Post 6 items on the input channel without advancing the clock + // And we should get 2 updates on the output channel inputChan <- 1 inputChan <- 2 inputChan <- 3 @@ -66,12 +68,39 @@ cancel() } +func TestBatch_MaxItems_CustomItemCountFunction(t *testing.T) { + ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) + testClock := clock.NewFakeClock(time.Now()) + inputChan := make(chan int) + result := newResultHolder() + // This function will mean each item on the input channel will count as 2 items + doubleItemCounterFunc := func(i int) int { return 2 } + batcher := NewBatcher[int](inputChan, defaultMaxItems, defaultMaxTimeOut, doubleItemCounterFunc, result.add) + batcher.clock = testClock + + go func() { + batcher.Run(ctx) + }() + + // Post 6 items on the input channel without advancing the clock + // And we should get 3 updates on the output channel + inputChan <- 1 + inputChan <- 2 + inputChan <- 3 + inputChan <- 4 + inputChan <- 5 + inputChan <- 6 + waitForExpectedEvents(ctx, result, 3) + assert.Equal(t, [][]int{{1, 2}, {3, 4}, {5, 6}}, result.result) + cancel() +} + func TestBatch_Time(t *testing.T) { ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second) testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() - batcher := NewBatcher[int](inputChan, defaultMaxItems, defaultMaxTimeOut, result.add) + batcher := NewBatcher[int](inputChan,
defaultMaxItems, defaultMaxTimeOut, defaultItemCounterFunc, result.add) batcher.clock = testClock go func() { @@ -93,7 +122,7 @@ func TestBatch_Time_WithIntialQuiet(t *testing.T) { testClock := clock.NewFakeClock(time.Now()) inputChan := make(chan int) result := newResultHolder() - batcher := NewBatcher[int](inputChan, defaultMaxItems, defaultMaxTimeOut, result.add) + batcher := NewBatcher[int](inputChan, defaultMaxItems, defaultMaxTimeOut, defaultItemCounterFunc, result.add) batcher.clock = testClock go func() { diff --git a/internal/common/ingest/ingestion_pipeline.go b/internal/common/ingest/ingestion_pipeline.go index 136d55fbf65..404372ef859 100644 --- a/internal/common/ingest/ingestion_pipeline.go +++ b/internal/common/ingest/ingestion_pipeline.go @@ -109,31 +109,72 @@ func (i *IngestionPipeline[T]) Run(ctx *armadacontext.Context) error { i.consumer = consumer defer closePulsar() } - pulsarMsgs := i.consumer.Chan() + pulsarMessageChannel := i.consumer.Chan() + pulsarMessages := make(chan pulsar.ConsumerMessage) - // Batch up messages - batchedMsgs := make(chan []pulsar.ConsumerMessage) - batcher := NewBatcher[pulsar.ConsumerMessage](pulsarMsgs, i.pulsarBatchSize, i.pulsarBatchDuration, func(b []pulsar.ConsumerMessage) { batchedMsgs <- b }) + // Consume pulsar messages + // Used to track if we are no longer receiving pulsar messages go func() { - batcher.Run(ctx) - close(batchedMsgs) + timeout := time.Minute * 2 + timer := time.NewTimer(timeout) + loop: + for { + if !timer.Stop() { + <-timer.C + } + timer.Reset(timeout) + select { + case msg, ok := <-pulsarMessageChannel: + if !ok { + // Channel closed + break loop + } + pulsarMessages <- msg + case <-timer.C: + log.Infof("No pulsar message received in %s", timeout) + } + } + close(pulsarMessages) }() // Convert to event sequences eventSequences := make(chan *EventSequencesWithIds) go func() { - for msg := range batchedMsgs { + for msg := range pulsarMessages { converted := unmarshalEventSequences(msg, i.metrics) eventSequences <- converted } close(eventSequences) }() + // Batch up messages + batchedEventSequences := make(chan *EventSequencesWithIds) + eventCounterFunc := func(seq *EventSequencesWithIds) int { return len(seq.EventSequences) } + eventPublisherFunc := func(b []*EventSequencesWithIds) { batchedEventSequences <- combineEventSequences(b) } + batcher := NewBatcher[*EventSequencesWithIds](eventSequences, i.pulsarBatchSize, i.pulsarBatchDuration, eventCounterFunc, eventPublisherFunc) + go func() { + batcher.Run(ctx) + close(batchedEventSequences) + }() + + // Log summary of batch + preprocessedBatchEventSequences := make(chan *EventSequencesWithIds) + go func() { + for msg := range batchedEventSequences { + logSummaryOfEventSequences(msg) + preprocessedBatchEventSequences <- msg + } + close(preprocessedBatchEventSequences) + }() + // Convert to instructions instructions := make(chan T) go func() { - for msg := range eventSequences { + for msg := range preprocessedBatchEventSequences { + start := time.Now() converted := i.converter.Convert(ctx, msg) + taken := time.Now().Sub(start) + log.Infof("Processed %d pulsar messages in %dms", len(msg.MessageIds), taken.Milliseconds()) instructions <- converted } close(instructions) @@ -202,23 +243,20 @@ func (i *IngestionPipeline[T]) subscribe() (pulsar.Consumer, func(), error) { }, nil } -func unmarshalEventSequences(batch []pulsar.ConsumerMessage, metrics *commonmetrics.Metrics) *EventSequencesWithIds { - sequences := make([]*armadaevents.EventSequence, 0, len(batch)) - messageIds := 
make([]pulsar.MessageID, len(batch)) - for i, msg := range batch { +func unmarshalEventSequences(msg pulsar.ConsumerMessage, metrics *commonmetrics.Metrics) *EventSequencesWithIds { + sequences := make([]*armadaevents.EventSequence, 0, 1) + messageIds := make([]pulsar.MessageID, 0, 1) - // Record the messageId- we need to record all message Ids, even if the event they contain is invalid - // As they must be acked at the end - messageIds[i] = msg.ID() - - // Try and unmarshall the proto - es, err := eventutil.UnmarshalEventSequence(armadacontext.Background(), msg.Payload()) - if err != nil { - metrics.RecordPulsarMessageError(commonmetrics.PulsarMessageErrorDeserialization) - log.WithError(err).Warnf("Could not unmarshal proto for msg %s", msg.ID()) - continue - } + // Record the messageId- we need to record all message Ids, even if the event they contain is invalid + // As they must be acked at the end + messageIds = append(messageIds, msg.ID()) + // Try and unmarshall the proto + es, err := eventutil.UnmarshalEventSequence(armadacontext.Background(), msg.Payload()) + if err != nil { + metrics.RecordPulsarMessageError(commonmetrics.PulsarMessageErrorDeserialization) + log.WithError(err).Warnf("Could not unmarshal proto for msg %s", msg.ID()) + } else { // Fill in time if it is not set // TODO - once created is set everywhere we can remove this for _, event := range es.Events { @@ -233,3 +271,28 @@ func unmarshalEventSequences(batch []pulsar.ConsumerMessage, metrics *commonmetr EventSequences: sequences, MessageIds: messageIds, } } + +func combineEventSequences(sequences []*EventSequencesWithIds) *EventSequencesWithIds { + combinedSequences := make([]*armadaevents.EventSequence, 0) + messageIds := []pulsar.MessageID{} + for _, seq := range sequences { + combinedSequences = append(combinedSequences, seq.EventSequences...) + messageIds = append(messageIds, seq.MessageIds...) + } + return &EventSequencesWithIds{ + EventSequences: combinedSequences, MessageIds: messageIds, + } +} + +func logSummaryOfEventSequences(sequence *EventSequencesWithIds) { + numberOfEvents := 0 + countOfEventsByType := map[string]int{} + for _, eventSequence := range sequence.EventSequences { + numberOfEvents += len(eventSequence.Events) + for _, e := range eventSequence.Events { + typeString := e.GetEventName() + countOfEventsByType[typeString] = countOfEventsByType[typeString] + 1 + } + } + log.Infof("Batch being processed contains %d event messages and %d events of type %v", len(sequence.MessageIds), numberOfEvents, countOfEventsByType) +} diff --git a/internal/common/pulsarutils/eventsequence.go b/internal/common/pulsarutils/eventsequence.go index 981aa824e09..270f6779837 100644 --- a/internal/common/pulsarutils/eventsequence.go +++ b/internal/common/pulsarutils/eventsequence.go @@ -16,10 +16,12 @@ import ( // CompactAndPublishSequences reduces the number of sequences to the smallest possible, // while respecting per-job set ordering and max Pulsar message size, and then publishes to Pulsar. -func CompactAndPublishSequences(ctx *armadacontext.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxMessageSizeInBytes uint) error { +func CompactAndPublishSequences(ctx *armadacontext.Context, sequences []*armadaevents.EventSequence, producer pulsar.Producer, maxEventsPerMessage int, maxMessageSizeInBytes uint) error { // Reduce the number of sequences to send to the minimum possible, - // and then break up any sequences larger than maxMessageSizeInBytes. 
sequences = eventutil.CompactEventSequences(sequences) + // Limit each sequence to contain no more than maxEventsPerMessage events + sequences = eventutil.LimitSequencesEventMessageCount(sequences, maxEventsPerMessage) + // Limit each sequence to be no larger than maxMessageSizeInBytes bytes sequences, err := eventutil.LimitSequencesByteSize(sequences, maxMessageSizeInBytes, true) if err != nil { return err } diff --git a/internal/common/pulsarutils/publisher.go b/internal/common/pulsarutils/publisher.go index 19213cce335..f4d32e84005 100644 --- a/internal/common/pulsarutils/publisher.go +++ b/internal/common/pulsarutils/publisher.go @@ -18,6 +18,8 @@ type Publisher interface { type PulsarPublisher struct { // Used to send messages to pulsar producer pulsar.Producer + // Maximum number of Events in each EventSequence + maxEventsPerMessage int // Maximum size (in bytes) of produced pulsar messages. // This must be below 4MB which is the pulsar message size limit maxAllowedMessageSize uint @@ -26,6 +28,7 @@ type PulsarPublisher struct { func NewPulsarPublisher( pulsarClient pulsar.Client, producerOptions pulsar.ProducerOptions, + maxEventsPerMessage int, maxAllowedMessageSize uint, ) (*PulsarPublisher, error) { producer, err := pulsarClient.CreateProducer(producerOptions) @@ -34,6 +37,7 @@ func NewPulsarPublisher( } return &PulsarPublisher{ producer: producer, + maxEventsPerMessage: maxEventsPerMessage, maxAllowedMessageSize: maxAllowedMessageSize, }, nil } @@ -45,6 +49,7 @@ func (p *PulsarPublisher) PublishMessages(ctx *armadacontext.Context, es *armada ctx, []*armadaevents.EventSequence{es}, p.producer, + p.maxEventsPerMessage, p.maxAllowedMessageSize) } diff --git a/internal/eventingester/configuration/types.go b/internal/eventingester/configuration/types.go index d2d4d6ad1a4..65f67a2ad32 100644 --- a/internal/eventingester/configuration/types.go +++ b/internal/eventingester/configuration/types.go @@ -19,9 +19,9 @@ type EventIngesterConfiguration struct { SubscriptionName string // Size in bytes above which event message will be compressed when inserting in the database MinMessageCompressionSize int + // Max size in bytes that messages inserted into the database will be + MaxOutputMessageSizeBytes int // Number of messages that will be batched together before being inserted into the database - BatchMessages int - // Size of messages that will be batched together before being inserted into the database BatchSize int // Maximum time since the last batch before a batch will be inserted into the database BatchDuration time.Duration diff --git a/internal/eventingester/ingester.go b/internal/eventingester/ingester.go index 46426733d28..e47a7eb1afa 100644 --- a/internal/eventingester/ingester.go +++ b/internal/eventingester/ingester.go @@ -64,7 +64,7 @@ func Run(config *configuration.EventIngesterConfiguration) { log.Errorf("Error creating compressor for consumer") panic(err) } - converter := convert.NewEventConverter(compressor, uint(config.BatchSize), metrics) + converter := convert.NewEventConverter(compressor, uint(config.MaxOutputMessageSizeBytes), metrics) ingester := ingest.NewIngestionPipeline( config.Pulsar, diff --git a/internal/lookoutingesterv2/configuration/types.go b/internal/lookoutingesterv2/configuration/types.go index c9540d3b21d..b6b6b154b7f 100644 --- a/internal/lookoutingesterv2/configuration/types.go +++ b/internal/lookoutingesterv2/configuration/types.go @@ -17,7 +17,7 @@ type LookoutIngesterV2Configuration struct { SubscriptionName string // Size in bytes above
which job specs will be compressed when inserting in the database MinJobSpecCompressionSize int - // Number of messages that will be batched together before being inserted into the database + // Number of event messages that will be batched together before being inserted into the database BatchSize int // Maximum time since the last batch before a batch will be inserted into the database BatchDuration time.Duration diff --git a/internal/lookoutingesterv2/lookoutdb/insertion_test.go b/internal/lookoutingesterv2/lookoutdb/insertion_test.go index d53db5da501..0806c841c83 100644 --- a/internal/lookoutingesterv2/lookoutdb/insertion_test.go +++ b/internal/lookoutingesterv2/lookoutdb/insertion_test.go @@ -611,6 +611,14 @@ func TestCreateJobErrorsBatch(t *testing.T) { jobError = getJobError(t, db, jobIdString) assert.Equal(t, expectedJobError, jobError) + // Insert again with a different value and check we don't overwrite + jobErrors := defaultInstructionSet().JobErrorsToCreate + jobErrors[0].Error = []byte{} + err = ldb.CreateJobErrorsBatch(armadacontext.Background(), jobErrors) + assert.Nil(t, err) + jobError = getJobError(t, db, jobIdString) + assert.Equal(t, expectedJobError, jobError) + // If a row is bad then we should return an error and no updates should happen _, err = ldb.db.Exec(armadacontext.Background(), "DELETE FROM job_error") assert.NoError(t, err) diff --git a/internal/lookoutv2/application.go b/internal/lookoutv2/application.go index 3868d7b1af1..47385dc9414 100644 --- a/internal/lookoutv2/application.go +++ b/internal/lookoutv2/application.go @@ -127,7 +127,7 @@ func Serve(configuration configuration.LookoutV2Config) error { ctx := armadacontext.New(params.HTTPRequest.Context(), logger) result, err := getJobErrorRepo.GetJobErrorMessage(ctx, params.GetJobErrorRequest.JobID) if err != nil { - return operations.NewGetJobRunDebugMessageBadRequest().WithPayload(conversions.ToSwaggerError(err.Error())) + return operations.NewGetJobErrorBadRequest().WithPayload(conversions.ToSwaggerError(err.Error())) } return operations.NewGetJobErrorOK().WithPayload(&operations.GetJobErrorOKBody{ ErrorString: result, diff --git a/internal/scheduler/api.go b/internal/scheduler/api.go index 01259561597..f455150d274 100644 --- a/internal/scheduler/api.go +++ b/internal/scheduler/api.go @@ -40,6 +40,8 @@ type ExecutorApi struct { allowedPriorities []int32 // Known priority classes priorityClasses map[string]priorityTypes.PriorityClass + // Max number of events in published Pulsar messages + maxEventsPerPulsarMessage int // Max size of Pulsar messages produced. maxPulsarMessageSizeBytes uint // See scheduling schedulingConfig. @@ -56,6 +58,7 @@ func NewExecutorApi(producer pulsar.Producer, nodeIdLabel string, priorityClassNameOverride *string, priorityClasses map[string]priorityTypes.PriorityClass, + maxEventsPerPulsarMessage int, maxPulsarMessageSizeBytes uint, ) (*ExecutorApi, error) { if len(allowedPriorities) == 0 { @@ -66,6 +69,7 @@ func NewExecutorApi(producer pulsar.Producer, jobRepository: jobRepository, executorRepository: executorRepository, allowedPriorities: allowedPriorities, + maxEventsPerPulsarMessage: maxEventsPerPulsarMessage, maxPulsarMessageSizeBytes: maxPulsarMessageSizeBytes, nodeIdLabel: nodeIdLabel, priorityClassNameOverride: priorityClassNameOverride, @@ -310,7 +314,7 @@ func addAnnotations(job *armadaevents.SubmitJob, annotations map[string]string) // ReportEvents publishes all eventSequences to Pulsar. The eventSequences are compacted for more efficient publishing. 
func (srv *ExecutorApi) ReportEvents(grpcCtx context.Context, list *executorapi.EventList) (*types.Empty, error) { ctx := armadacontext.FromGrpcCtx(grpcCtx) - err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxPulsarMessageSizeBytes) + err := pulsarutils.CompactAndPublishSequences(ctx, list.Events, srv.producer, srv.maxEventsPerPulsarMessage, srv.maxPulsarMessageSizeBytes) return &types.Empty{}, err } diff --git a/internal/scheduler/api_test.go b/internal/scheduler/api_test.go index 343c1896591..022aae6a2c4 100644 --- a/internal/scheduler/api_test.go +++ b/internal/scheduler/api_test.go @@ -324,6 +324,7 @@ func TestExecutorApi_LeaseJobRuns(t *testing.T) { "kubernetes.io/hostname", nil, priorityClasses, + 1000, 4*1024*1024, ) require.NoError(t, err) @@ -450,6 +451,7 @@ func TestExecutorApi_Publish(t *testing.T) { "kubernetes.io/hostname", nil, priorityClasses, + 1000, 4*1024*1024, ) diff --git a/internal/scheduler/configuration/configuration.go b/internal/scheduler/configuration/configuration.go index 9e3cf36c04e..54800d0e7d9 100644 --- a/internal/scheduler/configuration/configuration.go +++ b/internal/scheduler/configuration/configuration.go @@ -153,6 +153,8 @@ type SchedulingConfig struct { EnableAssertions bool // Only queues allocated more than this fraction of their fair share are considered for preemption. ProtectedFractionOfFairShare float64 `validate:"gte=0"` + // Use Max(AdjustedFairShare, FairShare) for fair share protection. If false then FairShare will be used. + UseAdjustedFairShareProtection bool // Armada adds a node selector term to every scheduled pod using this label with the node name as value. // This to force kube-scheduler to schedule pods on the node chosen by Armada. // For example, if NodeIdLabel is "kubernetes.io/hostname" and armada schedules a pod on node "myNode", diff --git a/internal/scheduler/context/context.go b/internal/scheduler/context/context.go index 865c96955e9..b4bc129a8ce 100644 --- a/internal/scheduler/context/context.go +++ b/internal/scheduler/context/context.go @@ -20,6 +20,7 @@ import ( armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/types" + schedulerconfig "github.com/armadaproject/armada/internal/scheduler/configuration" "github.com/armadaproject/armada/internal/scheduler/fairness" "github.com/armadaproject/armada/internal/scheduler/interfaces" "github.com/armadaproject/armada/internal/scheduler/internaltypes" @@ -109,6 +110,7 @@ func (sctx *SchedulingContext) ClearUnfeasibleSchedulingKeys() { func (sctx *SchedulingContext) AddQueueSchedulingContext( queue string, weight float64, initialAllocatedByPriorityClass schedulerobjects.QuantityByTAndResourceType[string], + demand schedulerobjects.ResourceList, limiter *rate.Limiter, ) error { if _, ok := sctx.QueueSchedulingContexts[queue]; ok { @@ -136,6 +138,7 @@ func (sctx *SchedulingContext) AddQueueSchedulingContext( Weight: weight, Limiter: limiter, Allocated: allocated, + Demand: demand, AllocatedByPriorityClass: initialAllocatedByPriorityClass, ScheduledResourcesByPriorityClass: make(schedulerobjects.QuantityByTAndResourceType[string]), EvictedResourcesByPriorityClass: make(schedulerobjects.QuantityByTAndResourceType[string]), @@ -166,6 +169,73 @@ func (sctx *SchedulingContext) TotalCost() float64 { return rv } +// UpdateFairShares updates FairShare and AdjustedFairShare for every QueueSchedulingContext associated with the 
+// SchedulingContext. This works by calculating a far share as queue_weight/sum_of_all_queue_weights and an +// AdjustedFairShare by resharing any unused capacity (as determined by a queue's demand) +func (sctx *SchedulingContext) UpdateFairShares() { + const maxIterations = 5 + + type queueInfo struct { + queueName string + adjustedShare float64 + fairShare float64 + weight float64 + cappedShare float64 + } + + queueInfos := make([]*queueInfo, 0, len(sctx.QueueSchedulingContexts)) + for queueName, qctx := range sctx.QueueSchedulingContexts { + cappedShare := 1.0 + if !sctx.TotalResources.IsZero() { + cappedShare = sctx.FairnessCostProvider.CostFromAllocationAndWeight(qctx.Demand, qctx.Weight) * qctx.Weight + } + queueInfos = append(queueInfos, &queueInfo{ + queueName: queueName, + adjustedShare: 0, + fairShare: qctx.Weight / sctx.WeightSum, + weight: qctx.Weight, + cappedShare: cappedShare, + }) + } + + // We do this so that we get deterministic output + slices.SortFunc(queueInfos, func(a, b *queueInfo) int { + return strings.Compare(a.queueName, b.queueName) + }) + + unallocated := 1.0 // this is the proportion of the cluster that we can share each time + + // We will reshare unused capacity until we've reshared 99% of all capacity or we've completed 5 iteration + for i := 0; i < maxIterations && unallocated > 0.01; i++ { + totalWeight := 0.0 + for _, q := range queueInfos { + totalWeight += q.weight + } + + for _, q := range queueInfos { + if q.weight > 0 { + share := (q.weight / totalWeight) * unallocated + q.adjustedShare += share + } + } + unallocated = 0.0 + for _, q := range queueInfos { + excessShare := q.adjustedShare - q.cappedShare + if excessShare > 0 { + q.adjustedShare = q.cappedShare + q.weight = 0.0 + unallocated += excessShare + } + } + } + + for _, q := range queueInfos { + qtx := sctx.QueueSchedulingContexts[q.queueName] + qtx.FairShare = q.fairShare + qtx.AdjustedFairShare = q.adjustedShare + } +} + func (sctx *SchedulingContext) ReportString(verbosity int32) string { var sb strings.Builder w := tabwriter.NewWriter(&sb, 1, 1, 1, ' ', 0) @@ -342,6 +412,13 @@ type QueueSchedulingContext struct { // Total resources assigned to the queue across all clusters by priority class priority. // Includes jobs scheduled during this invocation of the scheduler. Allocated schedulerobjects.ResourceList + // Total demand from this queue. This is essentially the cumulative resources of all non-terminal jobs at the + // start of the scheduling cycle + Demand schedulerobjects.ResourceList + // Fair share is the weight of this queue over the sum of the weights of all queues + FairShare float64 + // AdjustedFairShare modifies fair share such that queues that have a demand cost less than their fair share, have their fair share reallocated. + AdjustedFairShare float64 // Total resources assigned to the queue across all clusters by priority class. // Includes jobs scheduled during this invocation of the scheduler. AllocatedByPriorityClass schedulerobjects.QuantityByTAndResourceType[string] @@ -623,6 +700,9 @@ type JobSchedulingContext struct { // GangInfo holds all the information that is necessary to schedule a gang, // such as the lower and upper bounds on its size. GangInfo + // This is the node the pod is assigned to. 
+ // This is only set for evicted jobs and is set alongside adding an additionalNodeSelector for the node + AssignedNodeId string } func (jctx *JobSchedulingContext) String() string { @@ -663,6 +743,17 @@ func (jctx *JobSchedulingContext) Fail(unschedulableReason string) { } } +func (jctx *JobSchedulingContext) GetAssignedNodeId() string { + return jctx.AssignedNodeId +} + +func (jctx *JobSchedulingContext) SetAssignedNodeId(assignedNodeId string) { + if assignedNodeId != "" { + jctx.AssignedNodeId = assignedNodeId + jctx.AddNodeSelector(schedulerconfig.NodeIdLabel, assignedNodeId) + } +} + func (jctx *JobSchedulingContext) AddNodeSelector(key, value string) { if jctx.AdditionalNodeSelectors == nil { jctx.AdditionalNodeSelectors = map[string]string{key: value} @@ -671,15 +762,6 @@ func (jctx *JobSchedulingContext) AddNodeSelector(key, value string) { } } -func (jctx *JobSchedulingContext) GetNodeSelector(key string) (string, bool) { - if value, ok := jctx.AdditionalNodeSelectors[key]; ok { - return value, true - } else if value, ok := jctx.PodRequirements.NodeSelector[key]; ok { - return value, true - } - return "", false -} - type GangInfo struct { Id string Cardinality int diff --git a/internal/scheduler/context/context_test.go b/internal/scheduler/context/context_test.go index e538f966986..dcc9aeedcbc 100644 --- a/internal/scheduler/context/context_test.go +++ b/internal/scheduler/context/context_test.go @@ -9,6 +9,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" armadaslices "github.com/armadaproject/armada/internal/common/slices" + "github.com/armadaproject/armada/internal/scheduler/configuration" "github.com/armadaproject/armada/internal/scheduler/fairness" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" "github.com/armadaproject/armada/internal/scheduler/testfixtures" @@ -53,7 +54,7 @@ func TestSchedulingContextAccounting(t *testing.T) { }, } for _, queue := range []string{"A", "B"} { - err := sctx.AddQueueSchedulingContext(queue, priorityFactorByQueue[queue], allocatedByQueueAndPriorityClass[queue], nil) + err := sctx.AddQueueSchedulingContext(queue, priorityFactorByQueue[queue], allocatedByQueueAndPriorityClass[queue], schedulerobjects.ResourceList{}, nil) require.NoError(t, err) } @@ -96,3 +97,174 @@ func testSmallCpuJobSchedulingContext(queue, priorityClassName string) *JobSched GangInfo: EmptyGangInfo(job), } } + +func TestJobSchedulingContext_SetAssignedNodeId(t *testing.T) { + jctx := &JobSchedulingContext{} + + assert.Equal(t, "", jctx.GetAssignedNodeId()) + assert.Empty(t, jctx.AdditionalNodeSelectors) + + // Will not add a node selector if input is empty + jctx.SetAssignedNodeId("") + assert.Equal(t, "", jctx.GetAssignedNodeId()) + assert.Empty(t, jctx.AdditionalNodeSelectors) + + jctx.SetAssignedNodeId("node1") + assert.Equal(t, "node1", jctx.GetAssignedNodeId()) + assert.Len(t, jctx.AdditionalNodeSelectors, 1) + assert.Equal(t, map[string]string{configuration.NodeIdLabel: "node1"}, jctx.AdditionalNodeSelectors) +} + +func TestCalculateFairShares(t *testing.T) { + zeroCpu := schedulerobjects.ResourceList{ + Resources: map[string]resource.Quantity{"cpu": resource.MustParse("0")}, + } + oneCpu := schedulerobjects.ResourceList{ + Resources: map[string]resource.Quantity{"cpu": resource.MustParse("1")}, + } + fortyCpu := schedulerobjects.ResourceList{ + Resources: map[string]resource.Quantity{"cpu": resource.MustParse("40")}, + } + oneHundredCpu := schedulerobjects.ResourceList{ + Resources: map[string]resource.Quantity{"cpu": 
resource.MustParse("100")}, + } + oneThousandCpu := schedulerobjects.ResourceList{ + Resources: map[string]resource.Quantity{"cpu": resource.MustParse("1000")}, + } + tests := map[string]struct { + availableResources schedulerobjects.ResourceList + queueCtxs map[string]*QueueSchedulingContext + expectedFairShares map[string]float64 + expectedAdjustedFairShares map[string]float64 + }{ + "one queue, demand exceeds capacity": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneThousandCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 1.0}, + expectedAdjustedFairShares: map[string]float64{"queueA": 1.0}, + }, + "one queue, demand less than capacity": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 1.0}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.01}, + }, + "two queues, equal weights, demand exceeds capacity": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneThousandCpu}, + "queueB": {Weight: 1.0, Demand: oneThousandCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 0.5, "queueB": 0.5}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.5, "queueB": 0.5}, + }, + "two queues, equal weights, demand less than capacity for both queues": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneCpu}, + "queueB": {Weight: 1.0, Demand: oneCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 0.5, "queueB": 0.5}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.01, "queueB": 0.01}, + }, + "two queues, equal weights, demand less than capacity for one queue": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneCpu}, + "queueB": {Weight: 1.0, Demand: oneThousandCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 0.5, "queueB": 0.5}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.01, "queueB": 0.99}, + }, + "two queues, non equal weights, demand exceeds capacity for both queues": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneThousandCpu}, + "queueB": {Weight: 3.0, Demand: oneThousandCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 0.25, "queueB": 0.75}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.25, "queueB": 0.75}, + }, + "two queues, non equal weights, demand exceeds capacity for higher priority queue only": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneCpu}, + "queueB": {Weight: 3.0, Demand: oneThousandCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 0.25, "queueB": 0.75}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.01, "queueB": 0.99}, + }, + "two queues, non equal weights, demand exceeds capacity for lower priority queue only": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneThousandCpu}, + "queueB": {Weight: 3.0, Demand: oneCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 0.25, "queueB": 0.75}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.99, "queueB": 0.01}, + }, + 
"three queues, equal weights. Adjusted fair share requires multiple iterations": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneCpu}, + "queueB": {Weight: 1.0, Demand: fortyCpu}, + "queueC": {Weight: 1.0, Demand: oneThousandCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 1.0 / 3, "queueB": 1.0 / 3, "queueC": 1.0 / 3}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.01, "queueB": 0.4, "queueC": 0.59}, + }, + "No demand": { + availableResources: oneHundredCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: zeroCpu}, + "queueB": {Weight: 1.0, Demand: zeroCpu}, + "queueC": {Weight: 1.0, Demand: zeroCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 1.0 / 3, "queueB": 1.0 / 3, "queueC": 1.0 / 3}, + expectedAdjustedFairShares: map[string]float64{"queueA": 0.0, "queueB": 0.0, "queueC": 0.0}, + }, + "No capacity": { + availableResources: zeroCpu, + queueCtxs: map[string]*QueueSchedulingContext{ + "queueA": {Weight: 1.0, Demand: oneCpu}, + "queueB": {Weight: 1.0, Demand: oneCpu}, + "queueC": {Weight: 1.0, Demand: oneCpu}, + }, + expectedFairShares: map[string]float64{"queueA": 1.0 / 3, "queueB": 1.0 / 3, "queueC": 1.0 / 3}, + expectedAdjustedFairShares: map[string]float64{"queueA": 1.0 / 3, "queueB": 1.0 / 3, "queueC": 1.0 / 3}, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + fairnessCostProvider, err := fairness.NewDominantResourceFairness(tc.availableResources, []string{"cpu"}) + require.NoError(t, err) + sctx := NewSchedulingContext( + "executor", + "pool", + testfixtures.TestPriorityClasses, + testfixtures.TestDefaultPriorityClass, + fairnessCostProvider, + nil, + tc.availableResources, + ) + for qName, q := range tc.queueCtxs { + err = sctx.AddQueueSchedulingContext( + qName, q.Weight, schedulerobjects.QuantityByTAndResourceType[string]{}, q.Demand, nil) + require.NoError(t, err) + } + sctx.UpdateFairShares() + for qName, qctx := range sctx.QueueSchedulingContexts { + expectedFairShare, ok := tc.expectedFairShares[qName] + require.True(t, ok, "Expected fair share for queue %s not found", qName) + expectedAdjustedFairShare, ok := tc.expectedAdjustedFairShares[qName] + require.True(t, ok, "Expected adjusted fair share for queue %s not found", qName) + assert.Equal(t, expectedFairShare, qctx.FairShare, "Fair share for queue %s", qName) + assert.Equal(t, expectedAdjustedFairShare, qctx.AdjustedFairShare, "Adjusted Fair share for queue %s", qName) + } + }) + } +} diff --git a/internal/scheduler/gang_scheduler_test.go b/internal/scheduler/gang_scheduler_test.go index 67c2086baae..09ce9fc04dd 100644 --- a/internal/scheduler/gang_scheduler_test.go +++ b/internal/scheduler/gang_scheduler_test.go @@ -556,6 +556,7 @@ func TestGangScheduler(t *testing.T) { queue, priorityFactor, nil, + schedulerobjects.NewResourceList(0), rate.NewLimiter( rate.Limit(tc.SchedulingConfig.MaximumPerQueueSchedulingRate), tc.SchedulingConfig.MaximumPerQueueSchedulingBurst, diff --git a/internal/scheduler/nodedb/nodedb.go b/internal/scheduler/nodedb/nodedb.go index 1259c930933..d900f56d03f 100644 --- a/internal/scheduler/nodedb/nodedb.go +++ b/internal/scheduler/nodedb/nodedb.go @@ -544,7 +544,7 @@ func (nodeDb *NodeDb) SelectNodeForJobWithTxn(txn *memdb.Txn, jctx *schedulercon }() // If the nodeIdLabel selector is set, consider only that node. 
- if nodeId, ok := jctx.GetNodeSelector(configuration.NodeIdLabel); ok { + if nodeId := jctx.GetAssignedNodeId(); nodeId != "" { if it, err := txn.Get("nodes", "id", nodeId); err != nil { return nil, errors.WithStack(err) } else { @@ -808,11 +808,17 @@ func (nodeDb *NodeDb) selectNodeForPodWithItAtPriority( // It does this by considering all evicted jobs in the reverse order they would be scheduled in and preventing // from being re-scheduled the jobs that would be scheduled last. func (nodeDb *NodeDb) selectNodeForJobWithFairPreemption(txn *memdb.Txn, jctx *schedulercontext.JobSchedulingContext) (*internaltypes.Node, error) { + type consideredNode struct { + node *internaltypes.Node + availableResource internaltypes.ResourceList + evictedJobs []*EvictedJobSchedulingContext + staticRequirementsNotMet bool + } + pctx := jctx.PodSchedulingContext var selectedNode *internaltypes.Node - nodesById := make(map[string]*internaltypes.Node) - evictedJobSchedulingContextsByNodeId := make(map[string][]*EvictedJobSchedulingContext) + nodesById := make(map[string]*consideredNode) it, err := txn.ReverseLowerBound("evictedJobs", "index", math.MaxInt) if err != nil { return nil, errors.WithStack(err) @@ -821,59 +827,74 @@ func (nodeDb *NodeDb) selectNodeForJobWithFairPreemption(txn *memdb.Txn, jctx *s for obj := it.Next(); obj != nil && selectedNode == nil; obj = it.Next() { evictedJobSchedulingContext := obj.(*EvictedJobSchedulingContext) evictedJctx := evictedJobSchedulingContext.JobSchedulingContext - nodeId, ok := evictedJctx.GetNodeSelector(configuration.NodeIdLabel) - if !ok { - return nil, errors.Errorf("evicted job %s does not have a nodeIdLabel", evictedJctx.JobId) + nodeId := evictedJctx.GetAssignedNodeId() + if nodeId == "" { + return nil, errors.Errorf("evicted job %s does not have an assigned nodeId", evictedJctx.JobId) } node, ok := nodesById[nodeId] if !ok { - node, err = nodeDb.GetNodeWithTxn(txn, nodeId) + nodeFromDb, err := nodeDb.GetNodeWithTxn(txn, nodeId) if err != nil { return nil, errors.WithStack(err) } - node = node.UnsafeCopy() + node = &consideredNode{ + node: nodeFromDb, + availableResource: nodeFromDb.AllocatableByPriority[evictedPriority], + staticRequirementsNotMet: false, + evictedJobs: []*EvictedJobSchedulingContext{}, + } + nodesById[nodeId] = node } - err = nodeDb.unbindJobFromNodeInPlace(nodeDb.priorityClasses, evictedJctx.Job, node) - if err != nil { - return nil, err + if node.staticRequirementsNotMet { + continue } - evictedJobSchedulingContextsByNodeId[nodeId] = append(evictedJobSchedulingContextsByNodeId[nodeId], evictedJobSchedulingContext) - priority, ok := nodeDb.GetScheduledAtPriority(evictedJctx.JobId) - if !ok { - priority = evictedJctx.PodRequirements.Priority - } - if priority > maxPriority { - maxPriority = priority + // Evict job, update available resource + node.availableResource = node.availableResource.Add(evictedJctx.ResourceRequirements) + node.evictedJobs = append(node.evictedJobs, evictedJobSchedulingContext) + + dynamicRequirementsMet, _ := DynamicJobRequirementsMet(node.availableResource, jctx) + if !dynamicRequirementsMet { + continue } - matches, reason, err := JobRequirementsMet( - node, - // At this point, we've unbound the jobs running on the node. - // Hence, we should check if the job is schedulable at evictedPriority, - // since that indicates the job can be scheduled without causing further preemptions. 
- evictedPriority, - jctx, - ) + + staticRequirementsMet, reason, err := StaticJobRequirementsMet(node.node, jctx) if err != nil { return nil, err } - if matches { - selectedNode = node - } else { + if !staticRequirementsMet { + node.staticRequirementsNotMet = true s := nodeDb.stringFromPodRequirementsNotMetReason(reason) pctx.NumExcludedNodesByReason[s] += 1 + continue } - } - if selectedNode != nil { - pctx.NodeId = selectedNode.GetId() - pctx.PreemptedAtPriority = maxPriority - for _, evictedJobSchedulingContext := range evictedJobSchedulingContextsByNodeId[selectedNode.GetId()] { - if err := txn.Delete("evictedJobs", evictedJobSchedulingContext); err != nil { + + nodeCopy := node.node.UnsafeCopy() + for _, job := range node.evictedJobs { + // Remove preempted job from node + err = nodeDb.unbindJobFromNodeInPlace(nodeDb.priorityClasses, job.JobSchedulingContext.Job, nodeCopy) + if err != nil { + return nil, err + } + // Remove preempted job from list of evicted jobs + if err := txn.Delete("evictedJobs", job); err != nil { return nil, errors.WithStack(err) } + + priority, ok := nodeDb.GetScheduledAtPriority(evictedJctx.JobId) + if !ok { + priority = evictedJctx.PodRequirements.Priority + } + if priority > maxPriority { + maxPriority = priority + } } + + selectedNode = nodeCopy + pctx.NodeId = selectedNode.GetId() + pctx.PreemptedAtPriority = maxPriority } return selectedNode, nil } diff --git a/internal/scheduler/nodedb/nodedb_test.go b/internal/scheduler/nodedb/nodedb_test.go index b996cde278e..5c98c3bc770 100644 --- a/internal/scheduler/nodedb/nodedb_test.go +++ b/internal/scheduler/nodedb/nodedb_test.go @@ -69,13 +69,11 @@ func TestSelectNodeForPod_NodeIdLabel_Success(t *testing.T) { require.NotEmpty(t, nodeId) db, err := newNodeDbWithNodes(nodes) require.NoError(t, err) - jobs := testfixtures.WithNodeSelectorJobs( - map[string]string{schedulerconfig.NodeIdLabel: nodeId}, - testfixtures.N1Cpu4GiJobs("A", testfixtures.PriorityClass0, 1), - ) + jobs := testfixtures.N1Cpu4GiJobs("A", testfixtures.PriorityClass0, 1) jctxs := schedulercontext.JobSchedulingContextsFromJobs(testfixtures.TestPriorityClasses, jobs) for _, jctx := range jctxs { txn := db.Txn(false) + jctx.SetAssignedNodeId(nodeId) node, err := db.SelectNodeForJobWithTxn(txn, jctx) txn.Abort() require.NoError(t, err) @@ -96,13 +94,11 @@ func TestSelectNodeForPod_NodeIdLabel_Failure(t *testing.T) { require.NotEmpty(t, nodeId) db, err := newNodeDbWithNodes(nodes) require.NoError(t, err) - jobs := testfixtures.WithNodeSelectorJobs( - map[string]string{schedulerconfig.NodeIdLabel: "this node does not exist"}, - testfixtures.N1Cpu4GiJobs("A", testfixtures.PriorityClass0, 1), - ) + jobs := testfixtures.N1Cpu4GiJobs("A", testfixtures.PriorityClass0, 1) jctxs := schedulercontext.JobSchedulingContextsFromJobs(testfixtures.TestPriorityClasses, jobs) for _, jctx := range jctxs { txn := db.Txn(false) + jctx.SetAssignedNodeId("non-existent node") node, err := db.SelectNodeForJobWithTxn(txn, jctx) txn.Abort() if !assert.NoError(t, err) { diff --git a/internal/scheduler/preempting_queue_scheduler.go b/internal/scheduler/preempting_queue_scheduler.go index 75c972ca8c6..dc6c8c069df 100644 --- a/internal/scheduler/preempting_queue_scheduler.go +++ b/internal/scheduler/preempting_queue_scheduler.go @@ -2,6 +2,7 @@ package scheduler import ( "fmt" + "math" "reflect" "time" @@ -14,7 +15,6 @@ import ( armadamaps "github.com/armadaproject/armada/internal/common/maps" armadaslices "github.com/armadaproject/armada/internal/common/slices" 
"github.com/armadaproject/armada/internal/common/types" - schedulerconfig "github.com/armadaproject/armada/internal/scheduler/configuration" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" "github.com/armadaproject/armada/internal/scheduler/fairness" @@ -27,11 +27,12 @@ import ( // PreemptingQueueScheduler is a scheduler that makes a unified decisions on which jobs to preempt and schedule. // Uses QueueScheduler as a building block. type PreemptingQueueScheduler struct { - schedulingContext *schedulercontext.SchedulingContext - constraints schedulerconstraints.SchedulingConstraints - protectedFractionOfFairShare float64 - jobRepo JobRepository - nodeDb *nodedb.NodeDb + schedulingContext *schedulercontext.SchedulingContext + constraints schedulerconstraints.SchedulingConstraints + protectedFractionOfFairShare float64 + useAdjustedFairShareProtection bool + jobRepo JobRepository + nodeDb *nodedb.NodeDb // Maps job ids to the id of the node the job is associated with. // For scheduled or running jobs, that is the node the job is assigned to. // For preempted jobs, that is the node the job was preempted from. @@ -50,6 +51,7 @@ func NewPreemptingQueueScheduler( sctx *schedulercontext.SchedulingContext, constraints schedulerconstraints.SchedulingConstraints, protectedFractionOfFairShare float64, + useAdjustedFairShareProtection bool, jobRepo JobRepository, nodeDb *nodedb.NodeDb, initialNodeIdByJobId map[string]string, @@ -70,14 +72,15 @@ func NewPreemptingQueueScheduler( initialJobIdsByGangId[gangId] = maps.Clone(jobIds) } return &PreemptingQueueScheduler{ - schedulingContext: sctx, - constraints: constraints, - protectedFractionOfFairShare: protectedFractionOfFairShare, - jobRepo: jobRepo, - nodeDb: nodeDb, - nodeIdByJobId: maps.Clone(initialNodeIdByJobId), - jobIdsByGangId: initialJobIdsByGangId, - gangIdByJobId: maps.Clone(initialGangIdByJobId), + schedulingContext: sctx, + constraints: constraints, + protectedFractionOfFairShare: protectedFractionOfFairShare, + useAdjustedFairShareProtection: useAdjustedFairShareProtection, + jobRepo: jobRepo, + nodeDb: nodeDb, + nodeIdByJobId: maps.Clone(initialNodeIdByJobId), + jobIdsByGangId: initialJobIdsByGangId, + gangIdByJobId: maps.Clone(initialGangIdByJobId), } } @@ -128,8 +131,11 @@ func (sch *PreemptingQueueScheduler) Schedule(ctx *armadacontext.Context) (*Sche return false } if qctx, ok := sch.schedulingContext.QueueSchedulingContexts[job.Queue()]; ok { - fairShare := qctx.Weight / sch.schedulingContext.WeightSum actualShare := sch.schedulingContext.FairnessCostProvider.CostFromQueue(qctx) / totalCost + fairShare := qctx.FairShare + if sch.useAdjustedFairShareProtection { + fairShare = math.Max(qctx.AdjustedFairShare, fairShare) + } fractionOfFairShare := actualShare / fairShare if fractionOfFairShare <= sch.protectedFractionOfFairShare { return false @@ -436,7 +442,7 @@ func (sch *PreemptingQueueScheduler) evictionAssertions(evictorResult *EvictorRe if !jctx.IsEvicted { return errors.New("evicted job %s is not marked as such") } - if nodeId, ok := jctx.GetNodeSelector(schedulerconfig.NodeIdLabel); ok { + if nodeId := jctx.GetAssignedNodeId(); nodeId != "" { if _, ok := evictorResult.AffectedNodesById[nodeId]; !ok { return errors.Errorf("node id %s targeted by job %s is not marked as affected", nodeId, jobId) } @@ -858,7 +864,7 @@ func (evi *Evictor) Evict(ctx *armadacontext.Context, nodeDbTxn *memdb.Txn) (*Ev // TODO(albin): We can 
remove the checkOnlyDynamicRequirements flag in the nodeDb now that we've added the tolerations. jctx := schedulercontext.JobSchedulingContextFromJob(job) jctx.IsEvicted = true - jctx.AddNodeSelector(schedulerconfig.NodeIdLabel, node.GetId()) + jctx.SetAssignedNodeId(node.GetId()) evictedJctxsByJobId[job.Id()] = jctx jctx.AdditionalTolerations = append(jctx.AdditionalTolerations, node.GetTolerationsForTaints()...) diff --git a/internal/scheduler/preempting_queue_scheduler_test.go b/internal/scheduler/preempting_queue_scheduler_test.go index 93000154e43..e726c4b4a88 100644 --- a/internal/scheduler/preempting_queue_scheduler_test.go +++ b/internal/scheduler/preempting_queue_scheduler_test.go @@ -1276,6 +1276,42 @@ func TestPreemptingQueueScheduler(t *testing.T) { "C": 1, }, }, + "ProtectedFractionOfFairShare reshared": { + SchedulingConfig: testfixtures.WithProtectedFractionOfFairShareConfig( + 1.0, + testfixtures.TestSchedulingConfig(), + ), + Nodes: testfixtures.N32CpuNodes(1, testfixtures.TestPriorities), + Rounds: []SchedulingRound{ + { + JobsByQueue: map[string][]*jobdb.Job{ + "A": testfixtures.N1Cpu4GiJobs("A", testfixtures.PriorityClass2NonPreemptible, 16), // not preemptible + "B": testfixtures.N1Cpu4GiJobs("B", testfixtures.PriorityClass0, 11), + "C": testfixtures.N1Cpu4GiJobs("C", testfixtures.PriorityClass0, 3), + "D": testfixtures.N1Cpu4GiJobs("D", testfixtures.PriorityClass0, 2), + }, + ExpectedScheduledIndices: map[string][]int{ + "A": testfixtures.IntRange(0, 15), + "B": testfixtures.IntRange(0, 10), + "C": testfixtures.IntRange(0, 2), + "D": testfixtures.IntRange(0, 1), + }, + }, + { + // D submits one more job. No preemption occurs because B is below adjusted fair share + JobsByQueue: map[string][]*jobdb.Job{ + "D": testfixtures.N1Cpu4GiJobs("D", testfixtures.PriorityClass0, 1), + }, + }, + {}, // Empty round to make sure nothing changes. + }, + PriorityFactorByQueue: map[string]float64{ + "A": 1, + "B": 1, + "C": 1, + "D": 1, + }, + }, "DominantResourceFairness": { SchedulingConfig: testfixtures.TestSchedulingConfig(), Nodes: testfixtures.N32CpuNodes(1, testfixtures.TestPriorities), @@ -1697,6 +1733,8 @@ func TestPreemptingQueueScheduler(t *testing.T) { ) } + demandByQueue := map[string]schedulerobjects.ResourceList{} + // Run the scheduler. 
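The `UseAdjustedFairShareProtection` flag threaded through above only changes the eviction guard in `PreemptingQueueScheduler.Schedule`: a queue is protected up to `protectedFractionOfFairShare` of `max(FairShare, AdjustedFairShare)` rather than of the raw fair share. Below is a minimal restatement of that guard with hypothetical numbers; the real check closes over the scheduling context and cost provider rather than taking plain floats.

```go
package main

import (
	"fmt"
	"math"
)

// isProtectedFromEviction restates the guard used when deciding whether a
// queue's jobs may be evicted: a queue is protected while its actual share
// is at or below protectedFraction of its fair share. With adjusted fair
// share protection enabled, the larger of the two fair share values is used.
func isProtectedFromEviction(actualShare, fairShare, adjustedFairShare, protectedFraction float64, useAdjustedFairShareProtection bool) bool {
	share := fairShare
	if useAdjustedFairShareProtection {
		share = math.Max(adjustedFairShare, share)
	}
	return actualShare/share <= protectedFraction
}

func main() {
	// Hypothetical queue using 0.30 of the pool with a raw fair share of 0.25:
	// above its threshold, so its jobs may be considered for eviction.
	fmt.Println(isProtectedFromEviction(0.30, 0.25, 0.34, 1.0, false)) // false
	// If demand-based resharing raised its adjusted fair share to 0.34,
	// the same queue stays protected, as in the reshared test case above.
	fmt.Println(isProtectedFromEviction(0.30, 0.25, 0.34, 1.0, true)) // true
}
```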
ctx := armadacontext.Background() for i, round := range tc.Rounds { @@ -1712,6 +1750,12 @@ func TestPreemptingQueueScheduler(t *testing.T) { queuedJobs = append(queuedJobs, job.WithQueued(true)) roundByJobId[job.Id()] = i indexByJobId[job.Id()] = j + r, ok := demandByQueue[job.Queue()] + if !ok { + r = schedulerobjects.NewResourceList(len(job.PodRequirements().ResourceRequirements.Requests)) + demandByQueue[job.Queue()] = r + } + r.AddV1ResourceList(job.PodRequirements().ResourceRequirements.Requests) } } err = jobDbTxn.Upsert(queuedJobs) @@ -1733,6 +1777,12 @@ func TestPreemptingQueueScheduler(t *testing.T) { delete(gangIdByJobId, job.Id()) delete(jobIdsByGangId[gangId], job.Id()) } + r, ok := demandByQueue[job.Queue()] + if !ok { + r = schedulerobjects.NewResourceList(len(job.PodRequirements().ResourceRequirements.Requests)) + demandByQueue[job.Queue()] = r + } + r.SubV1ResourceList(job.PodRequirements().ResourceRequirements.Requests) } } } @@ -1774,6 +1824,7 @@ func TestPreemptingQueueScheduler(t *testing.T) { queue, weight, allocatedByQueueAndPriorityClass[queue], + demandByQueue[queue], limiterByQueue[queue], ) require.NoError(t, err) @@ -1785,10 +1836,12 @@ func TestPreemptingQueueScheduler(t *testing.T) { tc.SchedulingConfig, nil, ) + sctx.UpdateFairShares() sch := NewPreemptingQueueScheduler( sctx, constraints, tc.SchedulingConfig.ProtectedFractionOfFairShare, + tc.SchedulingConfig.UseAdjustedFairShareProtection, NewSchedulerJobRepositoryAdapter(jobDbTxn), nodeDb, nodeIdByJobId, @@ -2130,7 +2183,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { ) for queue, priorityFactor := range priorityFactorByQueue { weight := 1 / priorityFactor - err := sctx.AddQueueSchedulingContext(queue, weight, make(schedulerobjects.QuantityByTAndResourceType[string]), limiterByQueue[queue]) + err := sctx.AddQueueSchedulingContext(queue, weight, make(schedulerobjects.QuantityByTAndResourceType[string]), schedulerobjects.NewResourceList(0), limiterByQueue[queue]) require.NoError(b, err) } constraints := schedulerconstraints.NewSchedulingConstraints( @@ -2144,6 +2197,7 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { sctx, constraints, tc.SchedulingConfig.ProtectedFractionOfFairShare, + tc.SchedulingConfig.UseAdjustedFairShareProtection, NewSchedulerJobRepositoryAdapter(jobDbTxn), nodeDb, nil, @@ -2197,13 +2251,14 @@ func BenchmarkPreemptingQueueScheduler(b *testing.B) { ) for queue, priorityFactor := range priorityFactorByQueue { weight := 1 / priorityFactor - err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByQueueAndPriorityClass[queue], limiterByQueue[queue]) + err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByQueueAndPriorityClass[queue], schedulerobjects.NewResourceList(0), limiterByQueue[queue]) require.NoError(b, err) } sch := NewPreemptingQueueScheduler( sctx, constraints, tc.SchedulingConfig.ProtectedFractionOfFairShare, + tc.SchedulingConfig.UseAdjustedFairShareProtection, NewSchedulerJobRepositoryAdapter(jobDbTxn), nodeDb, nil, diff --git a/internal/scheduler/publisher.go b/internal/scheduler/publisher.go index caf6716b59b..104d96f495a 100644 --- a/internal/scheduler/publisher.go +++ b/internal/scheduler/publisher.go @@ -43,6 +43,8 @@ type PulsarPublisher struct { numPartitions int // Timeout after which async messages sends will be considered failed pulsarSendTimeout time.Duration + // Maximum number of Events in each EventSequence + maxEventsPerMessage int // Maximum size (in bytes) of produced pulsar messages. 
// This must be below 4MB which is the pulsar message size limit maxMessageBatchSize uint @@ -51,6 +53,7 @@ type PulsarPublisher struct { func NewPulsarPublisher( pulsarClient pulsar.Client, producerOptions pulsar.ProducerOptions, + maxEventsPerMessage int, pulsarSendTimeout time.Duration, ) (*PulsarPublisher, error) { partitions, err := pulsarClient.TopicPartitions(producerOptions.Topic) @@ -69,6 +72,7 @@ func NewPulsarPublisher( return &PulsarPublisher{ producer: producer, pulsarSendTimeout: pulsarSendTimeout, + maxEventsPerMessage: maxEventsPerMessage, maxMessageBatchSize: maxMessageBatchSize, numPartitions: len(partitions), }, nil @@ -78,6 +82,7 @@ func NewPulsarPublisher( // single event sequences up to maxMessageBatchSize. func (p *PulsarPublisher) PublishMessages(ctx *armadacontext.Context, events []*armadaevents.EventSequence, shouldPublish func() bool) error { sequences := eventutil.CompactEventSequences(events) + sequences = eventutil.LimitSequencesEventMessageCount(sequences, p.maxEventsPerMessage) sequences, err := eventutil.LimitSequencesByteSize(sequences, p.maxMessageBatchSize, true) if err != nil { return err diff --git a/internal/scheduler/publisher_test.go b/internal/scheduler/publisher_test.go index 6e4e693dcf5..c3fa778565c 100644 --- a/internal/scheduler/publisher_test.go +++ b/internal/scheduler/publisher_test.go @@ -120,7 +120,7 @@ func TestPulsarPublisher_TestPublish(t *testing.T) { }).AnyTimes() options := pulsar.ProducerOptions{Topic: topic} - publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second) + publisher, err := NewPulsarPublisher(mockPulsarClient, options, 1000, 5*time.Second) require.NoError(t, err) err = publisher.PublishMessages(ctx, tc.eventSequences, func() bool { return tc.amLeader }) @@ -191,7 +191,7 @@ func TestPulsarPublisher_TestPublishMarkers(t *testing.T) { options := pulsar.ProducerOptions{Topic: topic} ctx := armadacontext.TODO() - publisher, err := NewPulsarPublisher(mockPulsarClient, options, 5*time.Second) + publisher, err := NewPulsarPublisher(mockPulsarClient, options, 1000, 5*time.Second) require.NoError(t, err) published, err := publisher.PublishMarkers(ctx, uuid.New()) diff --git a/internal/scheduler/queue_scheduler_test.go b/internal/scheduler/queue_scheduler_test.go index 5f9493e041b..510b586f5de 100644 --- a/internal/scheduler/queue_scheduler_test.go +++ b/internal/scheduler/queue_scheduler_test.go @@ -569,6 +569,7 @@ func TestQueueScheduler(t *testing.T) { err := sctx.AddQueueSchedulingContext( q.Name, weight, tc.InitialAllocatedByQueueAndPriorityClass[q.Name], + schedulerobjects.NewResourceList(0), rate.NewLimiter( rate.Limit(tc.SchedulingConfig.MaximumPerQueueSchedulingRate), tc.SchedulingConfig.MaximumPerQueueSchedulingBurst, diff --git a/internal/scheduler/scheduler_metrics.go b/internal/scheduler/scheduler_metrics.go index bc81d4c92c2..04464bb7ac4 100644 --- a/internal/scheduler/scheduler_metrics.go +++ b/internal/scheduler/scheduler_metrics.go @@ -60,6 +60,15 @@ var fairSharePerQueueDesc = prometheus.NewDesc( }, nil, ) +var adjustedFairSharePerQueueDesc = prometheus.NewDesc( + fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "adjusted_fair_share"), + "Adjusted Fair share of each queue and pool.", + []string{ + "queue", + "pool", + }, nil, +) + var actualSharePerQueueDesc = prometheus.NewDesc( fmt.Sprintf("%s_%s_%s", NAMESPACE, SUBSYSTEM, "actual_share"), "Actual share of each queue and pool.", @@ -147,6 +156,7 @@ func generateSchedulerMetrics(schedulingRoundData schedulingRoundData) []prometh for key, value 
:= range schedulingRoundData.queuePoolData { result = append(result, prometheus.MustNewConstMetric(consideredJobsDesc, prometheus.GaugeValue, float64(value.numberOfJobsConsidered), key.queue, key.pool)) result = append(result, prometheus.MustNewConstMetric(fairSharePerQueueDesc, prometheus.GaugeValue, float64(value.fairShare), key.queue, key.pool)) + result = append(result, prometheus.MustNewConstMetric(adjustedFairSharePerQueueDesc, prometheus.GaugeValue, float64(value.adjustedFairShare), key.queue, key.pool)) result = append(result, prometheus.MustNewConstMetric(actualSharePerQueueDesc, prometheus.GaugeValue, float64(value.actualShare), key.queue, key.pool)) } for key, value := range schedulingRoundData.scheduledJobData { @@ -185,17 +195,15 @@ func (metrics *SchedulerMetrics) calculateQueuePoolMetrics(schedulingContexts [] result := make(map[queuePoolKey]queuePoolData) for _, schedContext := range schedulingContexts { totalCost := schedContext.TotalCost() - totalWeight := schedContext.WeightSum pool := schedContext.Pool for queue, queueContext := range schedContext.QueueSchedulingContexts { key := queuePoolKey{queue: queue, pool: pool} - fairShare := queueContext.Weight / totalWeight actualShare := schedContext.FairnessCostProvider.CostFromQueue(queueContext) / totalCost - result[key] = queuePoolData{ numberOfJobsConsidered: len(queueContext.UnsuccessfulJobSchedulingContexts) + len(queueContext.SuccessfulJobSchedulingContexts), - fairShare: fairShare, + fairShare: queueContext.FairShare, + adjustedFairShare: queueContext.AdjustedFairShare, actualShare: actualShare, } } @@ -224,4 +232,5 @@ type queuePoolData struct { numberOfJobsConsidered int actualShare float64 fairShare float64 + adjustedFairShare float64 } diff --git a/internal/scheduler/schedulerapp.go b/internal/scheduler/schedulerapp.go index 801aa1c596c..130a36d9af4 100644 --- a/internal/scheduler/schedulerapp.go +++ b/internal/scheduler/schedulerapp.go @@ -135,7 +135,7 @@ func Run(config schedulerconfig.Configuration) error { CompressionLevel: config.Pulsar.CompressionLevel, BatchingMaxSize: config.Pulsar.MaxAllowedMessageSize, Topic: config.Pulsar.JobsetEventsTopic, - }, config.PulsarSendTimeout) + }, config.Pulsar.MaxAllowedEventsPerMessage, config.PulsarSendTimeout) if err != nil { return errors.WithMessage(err, "error creating pulsar publisher") } @@ -182,6 +182,7 @@ func Run(config schedulerconfig.Configuration) error { config.Scheduling.NodeIdLabel, config.Scheduling.PriorityClassNameOverride, config.Scheduling.PriorityClasses, + config.Pulsar.MaxAllowedEventsPerMessage, config.Pulsar.MaxAllowedMessageSize, ) if err != nil { diff --git a/internal/scheduler/scheduling_algo.go b/internal/scheduler/scheduling_algo.go index 39f6d649404..0af6e58c435 100644 --- a/internal/scheduler/scheduling_algo.go +++ b/internal/scheduler/scheduling_algo.go @@ -2,7 +2,6 @@ package scheduler import ( "context" - "math/rand" "sort" "strings" "time" @@ -19,7 +18,6 @@ import ( "github.com/armadaproject/armada/internal/common/logging" armadaslices "github.com/armadaproject/armada/internal/common/slices" "github.com/armadaproject/armada/internal/common/stringinterner" - "github.com/armadaproject/armada/internal/common/util" "github.com/armadaproject/armada/internal/scheduler/configuration" schedulerconstraints "github.com/armadaproject/armada/internal/scheduler/constraints" schedulercontext "github.com/armadaproject/armada/internal/scheduler/context" @@ -64,8 +62,6 @@ type FairSchedulingAlgo struct { queueQuarantiner *quarantine.QueueQuarantiner // 
Function that is called every time an executor is scheduled. Useful for testing. onExecutorScheduled func(executor *schedulerobjects.Executor) - // rand and clock injected here for repeatable testing. - rand *rand.Rand clock clock.Clock stringInterner *stringinterner.StringInterner resourceListFactory *internaltypes.ResourceListFactory @@ -99,7 +95,6 @@ func NewFairSchedulingAlgo( nodeQuarantiner: nodeQuarantiner, queueQuarantiner: queueQuarantiner, onExecutorScheduled: func(executor *schedulerobjects.Executor) {}, - rand: util.NewThreadsafeRand(time.Now().UnixNano()), clock: clock.RealClock{}, stringInterner: stringInterner, resourceListFactory: resourceListFactory, @@ -253,7 +248,7 @@ func (it *JobQueueIteratorAdapter) Next() (*jobdb.Job, error) { type fairSchedulingAlgoContext struct { queues []*api.Queue priorityFactorByQueue map[string]float64 - isActiveByPoolByQueue map[string]map[string]bool + demandByPoolByQueue map[string]map[string]schedulerobjects.ResourceList totalCapacityByPool schedulerobjects.QuantityByTAndResourceType[string] jobsByExecutorId map[string][]*jobdb.Job nodeIdByJobId map[string]string @@ -297,13 +292,18 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con } // Create a map of jobs associated with each executor. - isActiveByPoolByQueue := make(map[string]map[string]bool, len(queues)) jobsByExecutorId := make(map[string][]*jobdb.Job) nodeIdByJobId := make(map[string]string) jobIdsByGangId := make(map[string]map[string]bool) gangIdByJobId := make(map[string]string) + demandByPoolByQueue := make(map[string]map[string]schedulerobjects.ResourceList) + for _, job := range txn.GetAll() { + if job.InTerminalState() { + continue + } + // Mark a queue being active for a given pool. A queue is defined as being active if it has a job running // on a pool or if a queued job is eligible for that pool pools := job.Pools() @@ -318,12 +318,17 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con } for _, pool := range pools { - isActiveByQueue, ok := isActiveByPoolByQueue[pool] + poolQueueResources, ok := demandByPoolByQueue[pool] + if !ok { + poolQueueResources = make(map[string]schedulerobjects.ResourceList, len(queues)) + demandByPoolByQueue[pool] = poolQueueResources + } + queueResources, ok := poolQueueResources[job.Queue()] if !ok { - isActiveByQueue = make(map[string]bool, len(queues)) + queueResources = schedulerobjects.NewResourceList(len(job.PodRequirements().ResourceRequirements.Requests)) + poolQueueResources[job.Queue()] = queueResources } - isActiveByQueue[job.Queue()] = true - isActiveByPoolByQueue[pool] = isActiveByQueue + queueResources.AddV1ResourceList(job.PodRequirements().ResourceRequirements.Requests) } if job.Queued() { @@ -371,7 +376,7 @@ func (l *FairSchedulingAlgo) newFairSchedulingAlgoContext(ctx *armadacontext.Con return &fairSchedulingAlgoContext{ queues: queues, priorityFactorByQueue: priorityFactorByQueue, - isActiveByPoolByQueue: isActiveByPoolByQueue, + demandByPoolByQueue: demandByPoolByQueue, totalCapacityByPool: totalCapacityByPool, jobsByExecutorId: jobsByExecutorId, nodeIdByJobId: nodeIdByJobId, @@ -437,14 +442,15 @@ func (l *FairSchedulingAlgo) scheduleOnExecutors( totalResources, ) - activeByQueue, ok := fsctx.isActiveByPoolByQueue[pool] + demandByQueue, ok := fsctx.demandByPoolByQueue[pool] if !ok { - activeByQueue = map[string]bool{} + demandByQueue = map[string]schedulerobjects.ResourceList{} } now := time.Now() for queue, priorityFactor := range fsctx.priorityFactorByQueue { 
- if !activeByQueue[queue] { + demand, hasDemand := demandByQueue[queue] + if !hasDemand { // To ensure fair share is computed only from active queues, i.e., queues with jobs queued or running. continue } @@ -474,10 +480,11 @@ func (l *FairSchedulingAlgo) scheduleOnExecutors( } queueLimiter.SetLimitAt(now, rate.Limit(l.schedulingConfig.MaximumPerQueueSchedulingRate*(1-quarantineFactor))) - if err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByPriorityClass, queueLimiter); err != nil { + if err := sctx.AddQueueSchedulingContext(queue, weight, allocatedByPriorityClass, demand, queueLimiter); err != nil { return nil, nil, err } } + sctx.UpdateFairShares() constraints := schedulerconstraints.NewSchedulingConstraints( pool, fsctx.totalCapacityByPool[pool], @@ -489,6 +496,7 @@ func (l *FairSchedulingAlgo) scheduleOnExecutors( sctx, constraints, l.schedulingConfig.ProtectedFractionOfFairShare, + l.schedulingConfig.UseAdjustedFairShareProtection, NewSchedulerJobRepositoryAdapter(fsctx.txn), nodeDb, fsctx.nodeIdByJobId, diff --git a/internal/scheduler/scheduling_algo_test.go b/internal/scheduler/scheduling_algo_test.go index c0b2e07cd3d..94825c2a287 100644 --- a/internal/scheduler/scheduling_algo_test.go +++ b/internal/scheduler/scheduling_algo_test.go @@ -493,6 +493,14 @@ func TestSchedule(t *testing.T) { dbJob := txn.GetById(job.Id()) assert.True(t, job.Equal(dbJob), "expected %v but got %v", job, dbJob) } + + // Check that we calculated fair share and adjusted fair share + for _, schCtx := range schedulerResult.SchedulingContexts { + for _, qtx := range schCtx.QueueSchedulingContexts { + assert.NotEqual(t, 0, qtx.AdjustedFairShare) + assert.NotEqual(t, 0, qtx.FairShare) + } + } }) } } diff --git a/internal/scheduler/simulator/simulator.go b/internal/scheduler/simulator/simulator.go index 592d71240fc..059e4ee0deb 100644 --- a/internal/scheduler/simulator/simulator.go +++ b/internal/scheduler/simulator/simulator.go @@ -443,6 +443,7 @@ func (s *Simulator) handleScheduleEvent(ctx *armadacontext.Context) error { var eventSequences []*armadaevents.EventSequence txn := s.jobDb.WriteTxn() defer txn.Abort() + demandByQueue := calculateDemandByQueue(txn.GetAll()) for _, pool := range s.ClusterSpec.Pools { for i := range pool.ClusterGroups { nodeDb := s.nodeDbByPoolAndExecutorGroup[pool.Name][i] @@ -469,6 +470,11 @@ func (s *Simulator) handleScheduleEvent(ctx *armadacontext.Context) error { sctx.Started = s.time for _, queue := range s.WorkloadSpec.Queues { + demand, hasDemand := demandByQueue[queue.Name] + if !hasDemand { + // To ensure fair share is computed only from active queues, i.e., queues with jobs queued or running. + continue + } limiter, ok := s.limiterByQueue[queue.Name] if !ok { limiter = rate.NewLimiter( @@ -482,6 +488,7 @@ func (s *Simulator) handleScheduleEvent(ctx *armadacontext.Context) error { queue.Name, queue.Weight, s.allocationByPoolAndQueueAndPriorityClass[pool.Name][queue.Name], + demand, limiter, ) if err != nil { @@ -500,6 +507,7 @@ func (s *Simulator) handleScheduleEvent(ctx *armadacontext.Context) error { sctx, constraints, s.schedulingConfig.ProtectedFractionOfFairShare, + s.schedulingConfig.UseAdjustedFairShareProtection, scheduler.NewSchedulerJobRepositoryAdapter(txn), nodeDb, // TODO: Necessary to support partial eviction. 
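The demand figures accumulated above feed `sctx.UpdateFairShares()`, which redistributes any share a queue cannot use (because its demand is below its fair share) across the remaining queues. Below is a standalone sketch of that iteration, not the production method: it takes weights and demand-capped shares directly as fractions of the pool, whereas the real code derives the cap from `Demand` via the `FairnessCostProvider`. With equal weights and demands of 1, 40 and 1000 CPU on a 100 CPU pool it settles near 0.01 / 0.40 / 0.59, in line with the multi-iteration case in the context tests.

```go
package main

import "fmt"

// adjustedShares redistributes share that capped queues cannot use to the
// queues that still have headroom, mirroring the loop in
// SchedulingContext.UpdateFairShares but with plain fractions as input.
func adjustedShares(weights, cappedShare map[string]float64) map[string]float64 {
	const maxIterations = 5
	adjusted := map[string]float64{}
	remainingWeight := map[string]float64{}
	for q, w := range weights {
		adjusted[q] = 0
		remainingWeight[q] = w
	}
	unallocated := 1.0 // fraction of the pool still to be shared out
	for i := 0; i < maxIterations && unallocated > 0.01; i++ {
		totalWeight := 0.0
		for _, w := range remainingWeight {
			totalWeight += w
		}
		if totalWeight == 0 {
			break
		}
		for q, w := range remainingWeight {
			adjusted[q] += (w / totalWeight) * unallocated
		}
		unallocated = 0
		for q := range weights {
			if excess := adjusted[q] - cappedShare[q]; excess > 0 {
				adjusted[q] = cappedShare[q] // a queue cannot use more than its demand
				remainingWeight[q] = 0       // stop giving it share in later passes
				unallocated += excess
			}
		}
	}
	return adjusted
}

func main() {
	// Equal weights; demands of 1, 40 and 1000 CPU on a 100 CPU pool give
	// demand-capped shares of 0.01, 0.40 and 10 (effectively uncapped).
	fmt.Println(adjustedShares(
		map[string]float64{"queueA": 1, "queueB": 1, "queueC": 1},
		map[string]float64{"queueA": 0.01, "queueB": 0.40, "queueC": 10},
	))
}
```

Each pass reshares only what the capped queues could not absorb, so low-demand queues settle at their caps while the remaining share accrues to queues whose demand exceeds their fair share.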
@@ -880,3 +888,20 @@ func maxTime(a, b time.Time) time.Time { func pointer[T any](t T) *T { return &t } + +func calculateDemandByQueue(jobs []*jobdb.Job) map[string]schedulerobjects.ResourceList { + queueResources := make(map[string]schedulerobjects.ResourceList) + + for _, job := range jobs { + if job.InTerminalState() { + continue + } + r, ok := queueResources[job.Queue()] + if !ok { + r = schedulerobjects.NewResourceList(len(job.PodRequirements().ResourceRequirements.Requests)) + queueResources[job.Queue()] = r + } + r.AddV1ResourceList(job.PodRequirements().ResourceRequirements.Requests) + } + return queueResources +} diff --git a/internal/scheduler/testfixtures/testfixtures.go b/internal/scheduler/testfixtures/testfixtures.go index dac90abbf39..36bac3dfb58 100644 --- a/internal/scheduler/testfixtures/testfixtures.go +++ b/internal/scheduler/testfixtures/testfixtures.go @@ -173,6 +173,7 @@ func TestSchedulingConfig() schedulerconfiguration.SchedulingConfig { ExecutorTimeout: 15 * time.Minute, MaxUnacknowledgedJobsPerExecutor: math.MaxInt, SupportedResourceTypes: GetTestSupportedResourceTypes(), + UseAdjustedFairShareProtection: true, } } diff --git a/internal/scheduleringester/config.go b/internal/scheduleringester/config.go index 67d5b945718..e08b5f64e1e 100644 --- a/internal/scheduleringester/config.go +++ b/internal/scheduleringester/config.go @@ -18,7 +18,7 @@ type Configuration struct { PriorityClasses map[string]types.PriorityClass // Pulsar subscription name SubscriptionName string - // Number of messages that will be batched together before being inserted into the database + // Number of event messages that will be batched together before being inserted into the database BatchSize int // Maximum time since the last batch before a batch will be inserted into the database BatchDuration time.Duration diff --git a/internal/scheduleringester/instructions.go b/internal/scheduleringester/instructions.go index 9e0468540d8..509a00f2d94 100644 --- a/internal/scheduleringester/instructions.go +++ b/internal/scheduleringester/instructions.go @@ -57,6 +57,7 @@ func (c *InstructionConverter) Convert(_ *armadacontext.Context, sequencesWithId operations = AppendDbOperation(operations, op) } } + log.Infof("Converted sequences into %d db operations", len(operations)) return &DbOperationsWithMessageIds{ Ops: operations, MessageIds: sequencesWithIds.MessageIds, @@ -117,6 +118,7 @@ func (c *InstructionConverter) dbOperationsFromEventSequence(es *armadaevents.Ev operationsFromEvent, err = c.handleJobValidated(event.GetJobValidated()) case *armadaevents.EventSequence_Event_ReprioritisedJob, *armadaevents.EventSequence_Event_ResourceUtilisation, + *armadaevents.EventSequence_Event_JobRunCancelled, *armadaevents.EventSequence_Event_StandaloneIngressInfo: // These events can all be safely ignored log.Debugf("Ignoring event type %T", event) diff --git a/magefiles/airflow.go b/magefiles/airflow.go index 8a1d970b5d7..bb160a64536 100644 --- a/magefiles/airflow.go +++ b/magefiles/airflow.go @@ -95,10 +95,5 @@ func AirflowOperator() error { return fmt.Errorf("failed to build Airflow Operator: %w", err) } - err = dockerRun("run", "--rm", "-v", "${PWD}/proto-airflow:/proto-airflow", "-v", "${PWD}:/go/src/armada", "-w", "/go/src/armada", "armada-airflow-operator-builder", "./scripts/build-airflow-operator.sh") - if err != nil { - return fmt.Errorf("failed to run build-airflow-operator.sh script: %w", err) - } - return nil } diff --git a/magefiles/tests.go b/magefiles/tests.go index 620aab1dd97..11cc3b2b6be 
100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -136,10 +136,6 @@ func runTest(name, outputFileName string) error { // Teste2eAirflow runs e2e tests for airflow func Teste2eAirflow() error { mg.Deps(AirflowOperator) - if err := BuildDockers("jobservice"); err != nil { - return err - } - cmd, err := go_CMD() if err != nil { return err @@ -149,29 +145,14 @@ func Teste2eAirflow() error { fmt.Println(err) } - if err := dockerRun("rm", "-f", "jobservice"); err != nil { - fmt.Println(err) - } - - err = dockerRun("run", "-d", "--name", "jobservice", "--network=kind", - "--mount", "type=bind,src=${PWD}/e2e,dst=/e2e", "gresearch/armada-jobservice", "run", "--config", - "/e2e/setup/jobservice.yaml") - if err != nil { - return err - } - err = dockerRun("run", "-v", "${PWD}/e2e:/e2e", "-v", "${PWD}/third_party/airflow:/code", - "--workdir", "/code", "-e", "ARMADA_SERVER=server", "-e", "ARMADA_PORT=50051", "-e", "JOB_SERVICE_HOST=jobservice", - "-e", "JOB_SERVICE_PORT=60003", "--entrypoint", "python3", "--network=kind", "armada-airflow-operator-builder:latest", - "-m", "pytest", "-v", "-s", "/code/tests/integration/test_airflow_operator_logic.py") + "--workdir", "/code", "-e", "ARMADA_SERVER=server", "-e", "ARMADA_PORT=50051", "--entrypoint", + "python3", "--network=kind", "armada-airflow-operator-builder:latest", + "-m", "pytest", "-v", "-s", "/code/test/integration/test_airflow_operator_logic.py") if err != nil { return err } - err = dockerRun("rm", "-f", "jobservice") - if err != nil { - return err - } return nil } diff --git a/pkg/armadaevents/events_util.go b/pkg/armadaevents/events_util.go index 0a142522e48..22e3cbf35ba 100644 --- a/pkg/armadaevents/events_util.go +++ b/pkg/armadaevents/events_util.go @@ -167,6 +167,58 @@ func (ev *EventSequence_Event) UnmarshalJSON(data []byte) error { return nil } +func (ev *EventSequence_Event) GetEventName() string { + switch ev.GetEvent().(type) { + case *EventSequence_Event_SubmitJob: + return "SubmitJob" + case *EventSequence_Event_JobRunLeased: + return "JobRunLeased" + case *EventSequence_Event_JobRunRunning: + return "JobRunRunning" + case *EventSequence_Event_JobRunSucceeded: + return "JobRunSucceeded" + case *EventSequence_Event_JobRunErrors: + return "JobRunErrors" + case *EventSequence_Event_JobSucceeded: + return "JobSucceeded" + case *EventSequence_Event_JobErrors: + return "JobErrors" + case *EventSequence_Event_JobPreemptionRequested: + return "JobPreemptionRequested" + case *EventSequence_Event_JobRunPreemptionRequested: + return "JobRunPreemptionRequested" + case *EventSequence_Event_ReprioritiseJob: + return "ReprioritiseJob" + case *EventSequence_Event_ReprioritiseJobSet: + return "ReprioritiseJobSet" + case *EventSequence_Event_CancelJob: + return "CancelJob" + case *EventSequence_Event_CancelJobSet: + return "CancelJobSet" + case *EventSequence_Event_CancelledJob: + return "CancelledJob" + case *EventSequence_Event_JobRunCancelled: + return "JobRunCancelled" + case *EventSequence_Event_JobRequeued: + return "JobRequeued" + case *EventSequence_Event_PartitionMarker: + return "PartitionMarker" + case *EventSequence_Event_JobRunPreempted: + return "JobRunPreempted" + case *EventSequence_Event_JobRunAssigned: + return "JobRunAssigned" + case *EventSequence_Event_JobValidated: + return "JobValidated" + case *EventSequence_Event_ReprioritisedJob: + return "ReprioritisedJob" + case *EventSequence_Event_ResourceUtilisation: + return "ResourceUtilisation" + case *EventSequence_Event_StandaloneIngressInfo: + return "StandaloneIngressInfo"
+ } + return "Unknown" +} + func (kmo *KubernetesMainObject) UnmarshalJSON(data []byte) error { if string(data) == "null" || string(data) == `""` { return nil diff --git a/pkg/armadaevents/events_util_test.go b/pkg/armadaevents/events_util_test.go index 7a3b1255e02..fde5e1be7ed 100644 --- a/pkg/armadaevents/events_util_test.go +++ b/pkg/armadaevents/events_util_test.go @@ -68,7 +68,7 @@ func generateFullES() ([]byte, error) { Containers: []v1.Container{ { Name: "container1", - Image: "alpine:3.10", + Image: "alpine:3.20.1", Args: []string{"sleep", "5s"}, Resources: v1.ResourceRequirements{ Requests: v1.ResourceList{"cpu": cpu, "memory": memory}, diff --git a/pkg/client/queue/permission_verb.go b/pkg/client/queue/permission_verb.go index e1bd5642d21..b062ca6b5bc 100644 --- a/pkg/client/queue/permission_verb.go +++ b/pkg/client/queue/permission_verb.go @@ -18,10 +18,10 @@ const ( ) // NewPermissionVerb returns PermissionVerb from input string. If input string doesn't match -// one of allowed verb values ["submit", "cancel", "reprioritize", "watch"], and error is returned. +// one of allowed verb values ["submit", "cancel", "preempt", "reprioritize", "watch"], and error is returned. func NewPermissionVerb(in string) (PermissionVerb, error) { switch verb := PermissionVerb(in); verb { - case PermissionVerbSubmit, PermissionVerbCancel, PermissionVerbReprioritize, PermissionVerbWatch: + case PermissionVerbSubmit, PermissionVerbCancel, PermissionVerbPreempt, PermissionVerbReprioritize, PermissionVerbWatch: return verb, nil default: return "", fmt.Errorf("invalid queue permission verb: %s", in) @@ -76,6 +76,7 @@ func AllPermissionVerbs() PermissionVerbs { return []PermissionVerb{ PermissionVerbSubmit, PermissionVerbCancel, + PermissionVerbPreempt, PermissionVerbReprioritize, PermissionVerbWatch, } diff --git a/pkg/client/queue/permission_verb_test.go b/pkg/client/queue/permission_verb_test.go index 0882cc7118f..e4b9e7bc114 100644 --- a/pkg/client/queue/permission_verb_test.go +++ b/pkg/client/queue/permission_verb_test.go @@ -14,6 +14,7 @@ func TestPermissionVerbUnmarshal(t *testing.T) { Verbs: []PermissionVerb{ PermissionVerbCancel, PermissionVerbReprioritize, + PermissionVerbPreempt, PermissionVerbSubmit, PermissionVerbWatch, }, diff --git a/scripts/build-airflow-operator.sh b/scripts/build-airflow-operator.sh deleted file mode 100755 index 531839ca025..00000000000 --- a/scripts/build-airflow-operator.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -# This script is intended to be run under the docker container at $ARMADADIR/build/python-api-client/ - -# make the python package armada.client, not pkg.api -mkdir -p proto-airflow -cp pkg/api/jobservice/jobservice.proto proto-airflow -sed -i 's/\([^\/]\)pkg\/api/\1jobservice/g' proto-airflow/*.proto - - -# generate python stubs -cd proto-airflow -python3 -m grpc_tools.protoc -I. --plugin=protoc-gen-mypy=$(which protoc-gen-mypy) --mypy_out=../third_party/airflow/armada/jobservice --python_out=../third_party/airflow/armada/jobservice --grpc_python_out=../third_party/airflow/armada/jobservice \ - jobservice.proto -cd .. -# This hideous code is because we can't use python package option in grpc. -# See https://github.com/protocolbuffers/protobuf/issues/7061 for an explanation. -# We need to import these packages as a module. 
-sed -i 's/import jobservice_pb2 as jobservice__pb2/from armada.jobservice import jobservice_pb2 as jobservice__pb2/g' third_party/airflow/armada/jobservice/*.py diff --git a/third_party/airflow/README.md b/third_party/airflow/README.md index 573b3861e5b..fe3bb0c8c18 100644 --- a/third_party/airflow/README.md +++ b/third_party/airflow/README.md @@ -1,12 +1,112 @@ # armada-airflow-operator -An Airflow operator for interfacing with the armada client - -## Background +Armada Airflow Operator, which manages airflow jobs. This allows Armada jobs to be run as part of an Airflow DAG + +## Overview + +The `ArmadaOperator` allows user to run an Armada Job as a task in an Airflow DAG. It handles job submission, job +state management and (optionally) log streaming back to Airflow. + +The Operator works by periodically polling Armada for the state of each job. As a result, it is only intended for DAGs +with tens or (at the limit) hundreds of concurrent jobs. + +## Installation + +`pip install armada-airflow` + +## Example Usage + +```python +from datetime import datetime + +from airflow import DAG +from armada_client.armada import submit_pb2 +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) + +from armada.operators.armada import ArmadaOperator + +def create_dummy_job(): + """ + Create a dummy job with a single container. + """ + + # For information on where this comes from, + # see https://github.com/kubernetes/api/blob/master/core/v1/generated.proto + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="sleep", + image="alpine:3.20.1", + args=["sh", "-c", "for i in $(seq 1 60); do echo $i; sleep 1; done"], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + requests={ + "cpu": api_resource.Quantity(string="1"), + "memory": api_resource.Quantity(string="1Gi"), + }, + limits={ + "cpu": api_resource.Quantity(string="1"), + "memory": api_resource.Quantity(string="1Gi"), + }, + ), + ) + ], + ) + + return submit_pb2.JobSubmitRequestItem( + priority=1, pod_spec=pod, namespace="armada" + ) + +armada_channel_args = {"target": "127.0.0.1:50051"} + + +with DAG( + "test_new_armada_operator", + description="Example DAG Showing Usage Of ArmadaOperator", + schedule=None, + start_date=datetime(2022, 1, 1), + catchup=False, +) as dag: + armada_task = ArmadaOperator( + name="non_deferrable_task", + task_id="1", + channel_args=armada_channel_args, + armada_queue="armada", + job_request=create_dummy_job(), + container_logs="sleep", + lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", + deferrable=False + ) + + armada_deferrable_task = ArmadaOperator( + name="deferrable_task", + task_id="2", + channel_args=armada_channel_args, + armada_queue="armada", + job_request=create_dummy_job(), + container_logs="sleep", + lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", + deferrable=True + ) + + armada_task >> armada_deferrable_task +``` +## Parameters -Airflow is an open source project focused on orchestrating Direct Acylic Graphs (DAGs) across different compute platforms. To interface Airflow with Armada, you should use our armada operator. 
+| Name | Description | Notes | +|----------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------| +| channel_args | A list of key-value pairs ([channel_arguments](https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments) in gRPC runtime) to configure the channel. | None | +| armada_queue | Armada queue to be used for the job | Make sure that Airflow user is permissioned on this queue | +| job_request | A `JobSubmitRequestItem` that is to be submitted to Armada as part of this task | Object contains a `core_v1.PodSpec` within it | +| job_set_prefix | A prefix for the JobSet name provided to Armada when submitting the job | The JobSet name submitted will be the Airflow `run_id` prefixed with this provided prefix | +| poll_interval | Integer number of seconds representing how ofter Airflow will poll Armada for Job Status. Defaults to 30 Seconds | Decreasing this makes the operator more responsive but comes at the cost of increased load on the Armada Server. Please do not decrease below 10 seconds. | +| container_logs | Name of the container in your job from which you wish to stream logs. If unset then no logs will be streamed | Only use this if you are running relatively few (<50) concurrent jobs | +| deferrable | Flag to specify whether to run the operator in Airflow Deferrable Mode | Defaults to True | -## Airflow +# Contributing The [airflow documentation](https://airflow.apache.org/) was used for setting up a simple test server. @@ -48,13 +148,13 @@ You can install the package via `pip3 install third_party/airflow`. You can use our tox file that streamlines development lifecycle. For development, you can install black, tox, mypy and flake8. -`python3.8 -m tox -e py38` will run unit tests. +`python3.10 -m tox -e py310` will run unit tests. -`python3.8 -m tox -e format` will run a format check +`python3.10 -m tox -e format` will run black on your code. -`python3.8 -m tox -e format-code` will run black on your code. +`python3.10 -m tox -e format-check` will run a format check. -`python3.8 -m tox -e docs` will generate a new sphinx doc. +`python3.10 -m tox -e docs` will generate a new sphinx doc. ## Releasing the client Armada-airflow releases are automated via Github Actions, for contributors with sufficient access to run them. diff --git a/third_party/airflow/armada/auth.py b/third_party/airflow/armada/auth.py new file mode 100644 index 00000000000..ca90b521ecf --- /dev/null +++ b/third_party/airflow/armada/auth.py @@ -0,0 +1,11 @@ +from typing import Dict, Any, Tuple, Protocol + + +""" We use this interface for objects fetching Kubernetes auth tokens. Since + it's used within the Trigger, it must be serialisable.""" + + +class TokenRetriever(Protocol): + def get_token(self) -> str: ... + + def serialize(self) -> Tuple[str, Dict[str, Any]]: ... diff --git a/third_party/airflow/armada/logs/log_consumer.py b/third_party/airflow/armada/logs/log_consumer.py new file mode 100644 index 00000000000..8fd8c31d3ef --- /dev/null +++ b/third_party/airflow/armada/logs/log_consumer.py @@ -0,0 +1,252 @@ +# Copyright 2016-2024 The Apache Software Foundation +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import queue +from datetime import timedelta +from http.client import HTTPResponse +from typing import Generator, TYPE_CHECKING, Callable, Awaitable, List + +from aiohttp.client_exceptions import ClientResponse +from airflow.utils.timezone import utcnow +from kubernetes.client import V1Pod +from kubernetes_asyncio.client import V1Pod as aio_V1Pod + +from armada.logs.utils import container_is_running, get_container_status + +if TYPE_CHECKING: + from urllib3.response import HTTPResponse # noqa: F811 + + +class PodLogsConsumerAsync: + """ + Responsible for pulling pod logs from a stream asynchronously, checking the + container status before reading data. + + This class contains a workaround for the issue + https://github.com/apache/airflow/issues/23497. + + :param response: HTTP response with logs + :param pod: Pod instance from Kubernetes client + :param read_pod_async: Callable returning a pod object that can be awaited on, + given (pod name, namespace) as arguments + :param container_name: Name of the container that we're reading logs from + :param post_termination_timeout: (Optional) The period of time in seconds + representing for how long time + logs are available after the container termination. + :param read_pod_cache_timeout: (Optional) The container's status cache lifetime. + The container status is cached to reduce API calls. 
+ + :meta private: + """ + + def __init__( + self, + response: ClientResponse, + pod_name: str, + namespace: str, + read_pod_async: Callable[[str, str], Awaitable[aio_V1Pod]], + container_name: str, + post_termination_timeout: int = 120, + read_pod_cache_timeout: int = 120, + ): + self.response = response + self.pod_name = pod_name + self.namespace = namespace + self._read_pod_async = read_pod_async + self.container_name = container_name + self.post_termination_timeout = post_termination_timeout + self.last_read_pod_at = None + self.read_pod_cache = None + self.read_pod_cache_timeout = read_pod_cache_timeout + self.log_queue = queue.Queue() + + def __aiter__(self): + return self + + async def __anext__(self): + r"""Yield log items divided by the '\n' symbol.""" + if not self.log_queue.empty(): + return self.log_queue.get() + + incomplete_log_item: List[bytes] = [] + if await self.logs_available(): + async for data_chunk in self.response.content: + if b"\n" in data_chunk: + log_items = data_chunk.split(b"\n") + for x in self._extract_log_items(incomplete_log_item, log_items): + if x is not None: + self.log_queue.put(x) + incomplete_log_item = self._save_incomplete_log_item(log_items[-1]) + else: + incomplete_log_item.append(data_chunk) + if not await self.logs_available(): + break + else: + self.response.close() + raise StopAsyncIteration + if incomplete_log_item: + item = b"".join(incomplete_log_item) + if item is not None: + self.log_queue.put(item) + + # Prevents method from returning None + if not self.log_queue.empty(): + return self.log_queue.get() + + self.response.close() + raise StopAsyncIteration + + @staticmethod + def _extract_log_items(incomplete_log_item: List[bytes], log_items: List[bytes]): + yield b"".join(incomplete_log_item) + log_items[0] + b"\n" + for x in log_items[1:-1]: + yield x + b"\n" + + @staticmethod + def _save_incomplete_log_item(sub_chunk: bytes): + return [sub_chunk] if [sub_chunk] else [] + + async def logs_available(self): + remote_pod = await self.read_pod() + if container_is_running(pod=remote_pod, container_name=self.container_name): + return True + container_status = get_container_status( + pod=remote_pod, container_name=self.container_name + ) + state = container_status.state if container_status else None + terminated = state.terminated if state else None + if terminated: + termination_time = terminated.finished_at + if termination_time: + return ( + termination_time + timedelta(seconds=self.post_termination_timeout) + > utcnow() + ) + return False + + async def read_pod(self): + _now = utcnow() + if ( + self.read_pod_cache is None + or self.last_read_pod_at + timedelta(seconds=self.read_pod_cache_timeout) + < _now + ): + self.read_pod_cache = await self._read_pod_async( + self.pod_name, self.namespace + ) + self.last_read_pod_at = _now + return self.read_pod_cache + + +class PodLogsConsumer: + """ + Responsible for pulling pod logs from a stream with checking a container status + before reading data. + + This class is a workaround for the issue + https://github.com/apache/airflow/issues/23497. + + :param response: HTTP response with logs + :param pod: Pod instance from Kubernetes client + :param read_pod: Callable returning a pod object given (pod name, namespace) as + arguments + :param container_name: Name of the container that we're reading logs from + :param post_termination_timeout: (Optional) The period of time in seconds + representing for how long time + logs are available after the container termination. 
+ :param read_pod_cache_timeout: (Optional) The container's status cache lifetime. + The container status is cached to reduce API calls. + + :meta private: + """ + + def __init__( + self, + response: HTTPResponse, + pod_name: str, + namespace: str, + read_pod: Callable[[str, str], V1Pod], + container_name: str, + post_termination_timeout: int = 120, + read_pod_cache_timeout: int = 120, + ): + self.response = response + self.pod_name = pod_name + self.namespace = namespace + self._read_pod = read_pod + self.container_name = container_name + self.post_termination_timeout = post_termination_timeout + self.last_read_pod_at = None + self.read_pod_cache = None + self.read_pod_cache_timeout = read_pod_cache_timeout + + def __iter__(self) -> Generator[bytes, None, None]: + r"""Yield log items divided by the '\n' symbol.""" + incomplete_log_item: List[bytes] = [] + if self.logs_available(): + for data_chunk in self.response.stream(amt=None, decode_content=True): + if b"\n" in data_chunk: + log_items = data_chunk.split(b"\n") + yield from self._extract_log_items(incomplete_log_item, log_items) + incomplete_log_item = self._save_incomplete_log_item(log_items[-1]) + else: + incomplete_log_item.append(data_chunk) + if not self.logs_available(): + break + if incomplete_log_item: + yield b"".join(incomplete_log_item) + + @staticmethod + def _extract_log_items(incomplete_log_item: List[bytes], log_items: List[bytes]): + yield b"".join(incomplete_log_item) + log_items[0] + b"\n" + for x in log_items[1:-1]: + yield x + b"\n" + + @staticmethod + def _save_incomplete_log_item(sub_chunk: bytes): + return [sub_chunk] if [sub_chunk] else [] + + def logs_available(self): + remote_pod = self.read_pod() + if container_is_running(pod=remote_pod, container_name=self.container_name): + return True + container_status = get_container_status( + pod=remote_pod, container_name=self.container_name + ) + state = container_status.state if container_status else None + terminated = state.terminated if state else None + if terminated: + termination_time = terminated.finished_at + if termination_time: + return ( + termination_time + timedelta(seconds=self.post_termination_timeout) + > utcnow() + ) + return False + + def read_pod(self): + _now = utcnow() + if ( + self.read_pod_cache is None + or self.last_read_pod_at + timedelta(seconds=self.read_pod_cache_timeout) + < _now + ): + self.read_pod_cache = self._read_pod(self.pod_name, self.namespace) + self.last_read_pod_at = _now + return self.read_pod_cache diff --git a/third_party/airflow/armada/logs/pod_log_manager.py b/third_party/airflow/armada/logs/pod_log_manager.py new file mode 100644 index 00000000000..20e8e51c852 --- /dev/null +++ b/third_party/airflow/armada/logs/pod_log_manager.py @@ -0,0 +1,550 @@ +# Copyright 2016-2024 The Apache Software Foundation +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import asyncio +import math +import time +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, cast, Optional + +import pendulum +import tenacity +from kubernetes import client, watch, config +from kubernetes_asyncio import client as async_client, config as async_config +from kubernetes.client.rest import ApiException +from pendulum import DateTime +from pendulum.parsing.exceptions import ParserError +from urllib3.exceptions import HTTPError as BaseHTTPError + +from airflow.exceptions import AirflowException +from airflow.utils.log.logging_mixin import LoggingMixin + +from armada.auth import TokenRetriever +from armada.logs.log_consumer import PodLogsConsumer, PodLogsConsumerAsync +from armada.logs.utils import container_is_running + +if TYPE_CHECKING: + from kubernetes.client.models.v1_pod import V1Pod + + +class PodPhase: + """ + Possible pod phases. + + See https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase. + """ + + PENDING = "Pending" + RUNNING = "Running" + FAILED = "Failed" + SUCCEEDED = "Succeeded" + + terminal_states = {FAILED, SUCCEEDED} + + +@dataclass +class PodLoggingStatus: + """Return the status of the pod and last log time when exiting from + `fetch_container_logs`.""" + + running: bool + last_log_time: DateTime | None + + +class PodLogManagerAsync(LoggingMixin): + """Monitor logs of Kubernetes pods asynchronously.""" + + def __init__( + self, + k8s_context: str, + token_retriever: Optional[TokenRetriever] = None, + ): + """ + Create the launcher. + + :param k8s_context: kubernetes context + :param token_retriever: Retrieves auth tokens + """ + super().__init__() + self._k8s_context = k8s_context + self._watch = watch.Watch() + self._k8s_client = None + self._token_retriever = token_retriever + + async def _refresh_k8s_auth_token(self, interval=60 * 5): + if self._token_retriever is not None: + while True: + await asyncio.sleep(interval) + self._k8s_client.api_client.configuration.api_key["BearerToken"] = ( + f"Bearer {self._token_retriever.get_token()}" + ) + + async def k8s_client(self) -> async_client: + await async_config.load_kube_config(context=self._k8s_context) + asyncio.create_task(self._refresh_k8s_auth_token()) + return async_client.CoreV1Api() + + async def fetch_container_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + *, + follow=False, + since_time: DateTime | None = None, + post_termination_timeout: int = 120, + ) -> PodLoggingStatus: + """ + Follow the logs of container and stream to airflow logging. Doesn't block whilst + logs are being fetched. + + Returns when container exits. + + Between when the pod starts and logs being available, there might be a delay due + to CSR not approved + and signed yet. In such situation, ApiException is thrown. This is why we are + retrying on this + specific exception. 
+ """ + # Can't await in constructor, so instantiating here + if self._k8s_client is None: + self._k8s_client = await self.k8s_client() + + @tenacity.retry( + retry=tenacity.retry_if_exception_type(ApiException), + stop=tenacity.stop_after_attempt(10), + wait=tenacity.wait_fixed(1), + ) + async def consume_logs( + *, + since_time: DateTime | None = None, + follow: bool = True, + logs: PodLogsConsumerAsync | None, + ) -> tuple[DateTime | None, PodLogsConsumerAsync | None]: + """ + Try to follow container logs until container completes. + + For a long-running container, sometimes the log read may be interrupted + Such errors of this kind are suppressed. + + Returns the last timestamp observed in logs. + """ + last_captured_timestamp = None + try: + logs = await self._read_pod_logs( + pod_name=pod_name, + namespace=namespace, + container_name=container_name, + timestamps=True, + since_seconds=( + math.ceil((pendulum.now() - since_time).total_seconds()) + if since_time + else None + ), + follow=follow, + post_termination_timeout=post_termination_timeout, + ) + message_to_log = None + message_timestamp = None + progress_callback_lines = [] + try: + async for raw_line in logs: + line = raw_line.decode("utf-8", errors="backslashreplace") + line_timestamp, message = self._parse_log_line(line) + if line_timestamp: # detect new log line + if message_to_log is None: # first line in the log + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines.append(line) + else: # previous log line is complete + self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines = [line] + else: # continuation of the previous log line + message_to_log = f"{message_to_log}\n{message}" + progress_callback_lines.append(line) + finally: + if message_to_log is not None: + self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + except BaseHTTPError as e: + self.log.warning( + "Reading of logs interrupted for container %r with error %r; will " + "retry. " + "Set log level to DEBUG for traceback.", + container_name, + e, + ) + self.log.debug( + "Traceback for interrupted logs read for pod %r", + pod_name, + exc_info=True, + ) + return last_captured_timestamp or since_time, logs + + # note: `read_pod_logs` follows the logs, so we shouldn't necessarily *need* to + # loop as we do here. But in a long-running process we might temporarily lose + # connectivity. + # So the looping logic is there to let us resume following the logs. + logs = None + last_log_time = since_time + while True: + last_log_time, logs = await consume_logs( + since_time=last_log_time, + follow=follow, + logs=logs, + ) + if not await self._container_is_running_async( + pod_name, namespace, container_name=container_name + ): + return PodLoggingStatus(running=False, last_log_time=last_log_time) + if not follow: + return PodLoggingStatus(running=True, last_log_time=last_log_time) + else: + self.log.warning( + "Pod %s log read interrupted but container %s still running", + pod_name, + container_name, + ) + time.sleep(1) + + def _parse_log_line(self, line: str) -> tuple[DateTime | None, str]: + """ + Parse K8s log line and returns the final state. 
+ + :param line: k8s log line + :return: timestamp and log message + """ + timestamp, sep, message = line.strip().partition(" ") + if not sep: + return None, line + try: + last_log_time = cast(DateTime, pendulum.parse(timestamp)) + except ParserError: + return None, line + return last_log_time, message + + async def _container_is_running_async( + self, pod_name: str, namespace: str, container_name: str + ) -> bool: + """Read pod and checks if container is running.""" + remote_pod = await self.read_pod(pod_name, namespace) + return container_is_running(pod=remote_pod, container_name=container_name) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + async def _read_pod_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + tail_lines: int | None = None, + timestamps: bool = False, + since_seconds: int | None = None, + follow=True, + post_termination_timeout: int = 120, + ) -> PodLogsConsumerAsync: + """Read log from the POD.""" + additional_kwargs = {} + if since_seconds: + additional_kwargs["since_seconds"] = since_seconds + + if tail_lines: + additional_kwargs["tail_lines"] = tail_lines + + try: + logs = await self._k8s_client.read_namespaced_pod_log( + name=pod_name, + namespace=namespace, + container=container_name, + follow=follow, + timestamps=timestamps, + _preload_content=False, + **additional_kwargs, + ) + except BaseHTTPError: + self.log.exception("There was an error reading the kubernetes API.") + raise + + return PodLogsConsumerAsync( + response=logs, + pod_name=pod_name, + namespace=namespace, + read_pod_async=self.read_pod, + container_name=container_name, + post_termination_timeout=post_termination_timeout, + ) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + async def read_pod(self, pod_name: str, namespace: str) -> V1Pod: + """Read POD information.""" + try: + return await self._k8s_client.read_namespaced_pod(pod_name, namespace) + except BaseHTTPError as e: + raise AirflowException( + f"There was an error reading the kubernetes API: {e}" + ) + + +class PodLogManager(LoggingMixin): + """Monitor logs of Kubernetes pods.""" + + def __init__( + self, k8s_context: str, token_retriever: Optional[TokenRetriever] = None + ): + """ + Create the launcher. + + :param k8s_context: kubernetes context + :param token_retriever: Retrieves auth tokens + """ + super().__init__() + self._k8s_context = k8s_context + self._watch = watch.Watch() + self._token_retriever = token_retriever + + def _refresh_k8s_auth_token(self): + if self._token_retriever is not None: + self._k8s_client.api_client.configuration.api_key["BearerToken"] = ( + f"Bearer {self._token_retriever.get_token()}" + ) + + @cached_property + def _k8s_client(self) -> client: + config.load_kube_config(context=self._k8s_context) + return client.CoreV1Api() + + def fetch_container_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + *, + follow=False, + since_time: DateTime | None = None, + post_termination_timeout: int = 120, + ) -> PodLoggingStatus: + """ + Follow the logs of container and stream to airflow logging. + + Returns when container exits. + + Between when the pod starts and logs being available, there might be a delay due + to CSR not approved + and signed yet. In such situation, ApiException is thrown. This is why we are + retrying on this + specific exception. 
+ """ + + @tenacity.retry( + retry=tenacity.retry_if_exception_type(ApiException), + stop=tenacity.stop_after_attempt(10), + wait=tenacity.wait_fixed(1), + ) + def consume_logs( + *, + since_time: DateTime | None = None, + follow: bool = True, + logs: PodLogsConsumer | None, + ) -> tuple[DateTime | None, PodLogsConsumer | None]: + """ + Try to follow container logs until container completes. + + For a long-running container, sometimes the log read may be interrupted + Such errors of this kind are suppressed. + + Returns the last timestamp observed in logs. + """ + last_captured_timestamp = None + try: + logs = self._read_pod_logs( + pod_name=pod_name, + namespace=namespace, + container_name=container_name, + timestamps=True, + since_seconds=( + math.ceil((pendulum.now() - since_time).total_seconds()) + if since_time + else None + ), + follow=follow, + post_termination_timeout=post_termination_timeout, + ) + message_to_log = None + message_timestamp = None + progress_callback_lines = [] + try: + for raw_line in logs: + line = raw_line.decode("utf-8", errors="backslashreplace") + line_timestamp, message = self._parse_log_line(line) + if line_timestamp: # detect new log line + if message_to_log is None: # first line in the log + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines.append(line) + else: # previous log line is complete + self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines = [line] + else: # continuation of the previous log line + message_to_log = f"{message_to_log}\n{message}" + progress_callback_lines.append(line) + finally: + if message_to_log is not None: + self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + except BaseHTTPError as e: + self.log.warning( + "Reading of logs interrupted for container %r with error %r; will " + "retry. " + "Set log level to DEBUG for traceback.", + container_name, + e, + ) + self.log.debug( + "Traceback for interrupted logs read for pod %r", + pod_name, + exc_info=True, + ) + return last_captured_timestamp or since_time, logs + + # note: `read_pod_logs` follows the logs, so we shouldn't necessarily *need* to + # loop as we do here. But in a long-running process we might temporarily lose + # connectivity. + # So the looping logic is there to let us resume following the logs. + logs = None + last_log_time = since_time + while True: + last_log_time, logs = consume_logs( + since_time=last_log_time, + follow=follow, + logs=logs, + ) + if not self._container_is_running( + pod_name, namespace, container_name=container_name + ): + return PodLoggingStatus(running=False, last_log_time=last_log_time) + if not follow: + return PodLoggingStatus(running=True, last_log_time=last_log_time) + else: + self.log.warning( + "Pod %s log read interrupted but container %s still running", + pod_name, + container_name, + ) + time.sleep(1) + self._refresh_k8s_auth_token() + + def _parse_log_line(self, line: str) -> tuple[DateTime | None, str]: + """ + Parse K8s log line and returns the final state. 
+ + :param line: k8s log line + :return: timestamp and log message + """ + timestamp, sep, message = line.strip().partition(" ") + if not sep: + return None, line + try: + last_log_time = cast(DateTime, pendulum.parse(timestamp)) + except ParserError: + return None, line + return last_log_time, message + + def _container_is_running( + self, pod_name: str, namespace: str, container_name: str + ) -> bool: + """Read pod and checks if container is running.""" + remote_pod = self.read_pod(pod_name, namespace) + return container_is_running(pod=remote_pod, container_name=container_name) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + def _read_pod_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + tail_lines: int | None = None, + timestamps: bool = False, + since_seconds: int | None = None, + follow=True, + post_termination_timeout: int = 120, + ) -> PodLogsConsumer: + """Read log from the POD.""" + additional_kwargs = {} + if since_seconds: + additional_kwargs["since_seconds"] = since_seconds + + if tail_lines: + additional_kwargs["tail_lines"] = tail_lines + + try: + logs = self._k8s_client.read_namespaced_pod_log( + name=pod_name, + namespace=namespace, + container=container_name, + follow=follow, + timestamps=timestamps, + _preload_content=False, + **additional_kwargs, + ) + except BaseHTTPError: + self.log.exception("There was an error reading the kubernetes API.") + raise + + return PodLogsConsumer( + response=logs, + pod_name=pod_name, + namespace=namespace, + read_pod=self.read_pod, + container_name=container_name, + post_termination_timeout=post_termination_timeout, + ) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + def read_pod(self, pod_name: str, namespace: str) -> V1Pod: + """Read POD information.""" + try: + return self._k8s_client.read_namespaced_pod(pod_name, namespace) + except BaseHTTPError as e: + raise AirflowException( + f"There was an error reading the kubernetes API: {e}" + ) diff --git a/third_party/airflow/armada/logs/utils.py b/third_party/airflow/armada/logs/utils.py new file mode 100644 index 00000000000..ade71ba5fbe --- /dev/null +++ b/third_party/airflow/armada/logs/utils.py @@ -0,0 +1,55 @@ +# Copyright 2016-2024 The Apache Software Foundation +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
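The `_parse_log_line` helper above depends on Kubernetes prefixing every log line with an RFC3339 timestamp when `timestamps=True` is passed to `read_namespaced_pod_log`. A minimal standalone sketch of the same splitting logic, using a hypothetical log line (this is illustration only, not part of the diff):

```python
import pendulum
from pendulum.parsing.exceptions import ParserError


def parse_log_line(line: str):
    # Split "<timestamp> <message>" on the first space; fall back to the raw
    # line if there is no separator or the timestamp does not parse.
    timestamp, sep, message = line.strip().partition(" ")
    if not sep:
        return None, line
    try:
        return pendulum.parse(timestamp), message
    except ParserError:
        return None, line


# Hypothetical example of a line produced with timestamps=True:
print(parse_log_line("2024-05-01T12:00:00.000000Z step 1 of 60"))
```

The parsed timestamp is what the log managers carry forward as `since_time` when an interrupted stream is resumed, so only lines newer than the last one seen are fetched again.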
+ +from typing import TYPE_CHECKING + +from kubernetes.client import V1Pod, V1ContainerStatus + +if TYPE_CHECKING: + from kubernetes.client.models.v1_container_status import ( # noqa: F811 + V1ContainerStatus, + ) + from kubernetes.client.models.v1_pod import V1Pod # noqa: F811 + + +def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus: + """Retrieve container status.""" + container_statuses = pod.status.container_statuses if pod and pod.status else None + if container_statuses: + # In general the variable container_statuses can store multiple items matching + # different containers. + # The following generator expression yields all items that have name equal to + # the container_name. + # The function next() here calls the generator to get only the first value. If + # there's nothing found + # then None is returned. + return next((x for x in container_statuses if x.name == container_name), None) + return None + + +def container_is_running(pod: V1Pod, container_name: str) -> bool: + """ + Examine V1Pod ``pod`` to determine whether ``container_name`` is running. + + If that container is present and running, returns True. Returns False otherwise. + """ + container_status = get_container_status(pod, container_name) + if not container_status: + return False + return container_status.state.running is not None diff --git a/third_party/airflow/armada/model.py b/third_party/airflow/armada/model.py new file mode 100644 index 00000000000..80e6e0d0a77 --- /dev/null +++ b/third_party/airflow/armada/model.py @@ -0,0 +1,83 @@ +import importlib +from typing import Tuple, Any, Optional, Sequence, Dict + +import grpc + + +""" This class exists so that we can retain our connection to the Armada Query API + when using the deferrable Armada Airflow Operator. 
Airflow requires any state + within deferrable operators be serialisable, unfortunately grpc.Channel isn't + itself serialisable.""" + + +class GrpcChannelArgs: + def __init__( + self, + target: str, + options: Optional[Sequence[Tuple[str, Any]]] = None, + compression: Optional[grpc.Compression] = None, + auth: Optional[grpc.AuthMetadataPlugin] = None, + auth_details: Optional[Dict[str, Any]] = None, + ): + self.target = target + self.options = options + self.compression = compression + if auth: + self.auth = auth + elif auth_details: + classpath, kwargs = auth_details + module_path, class_name = classpath.rsplit( + ".", 1 + ) # Split the classpath to module and class name + module = importlib.import_module( + module_path + ) # Dynamically import the module + cls = getattr(module, class_name) # Get the class from the module + self.auth = cls( + **kwargs + ) # Instantiate the class with the deserialized kwargs + else: + self.auth = None + + def serialize(self) -> Dict[str, Any]: + auth_details = self.auth.serialize() if self.auth else None + return { + "target": self.target, + "options": self.options, + "compression": self.compression, + "auth_details": auth_details, + } + + def channel(self) -> grpc.Channel: + if self.auth is None: + return grpc.insecure_channel( + target=self.target, options=self.options, compression=self.compression + ) + + return grpc.secure_channel( + target=self.target, + options=self.options, + compression=self.compression, + credentials=grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), + grpc.metadata_call_credentials(self.auth), + ), + ) + + def aio_channel(self) -> grpc.aio.Channel: + if self.auth is None: + return grpc.aio.insecure_channel( + target=self.target, + options=self.options, + compression=self.compression, + ) + + return grpc.aio.secure_channel( + target=self.target, + options=self.options, + compression=self.compression, + credentials=grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), + grpc.metadata_call_credentials(self.auth), + ), + ) diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index 33475651275..cb9fd361c27 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -16,151 +16,331 @@ # specific language governing permissions and limitations # under the License. 
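Because `grpc.Channel` objects are not serialisable, the deferrable path hands a plain dict produced by `GrpcChannelArgs.serialize()` to the triggerer and rebuilds the channel there. A short sketch of that round trip, assuming the `armada.model` module from this diff is importable and using a placeholder target address:

```python
from armada.model import GrpcChannelArgs

# Build the channel arguments once in the operator (placeholder address).
args = GrpcChannelArgs(target="127.0.0.1:50051")

# serialize() yields a plain dict (target, options, compression, auth_details)
# that Airflow can persist when the task defers.
payload = args.serialize()

# On the trigger side the same dict reconstructs an equivalent object; an
# insecure channel is created here because no auth plugin was supplied.
restored = GrpcChannelArgs(**payload)
channel = restored.channel()
```

The async variant `aio_channel()` follows the same pattern for use inside the trigger's event loop.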
-import logging -from typing import Optional, List, Sequence +import os +import time +from functools import lru_cache, cached_property +from typing import Optional, Sequence, Any, Dict -from airflow.models import BaseOperator +import jinja2 +from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.utils.context import Context +from airflow.models import BaseOperator +from airflow.utils.context import Context +from airflow.utils.log.logging_mixin import LoggingMixin +from armada_client.armada.job_pb2 import JobRunDetails +from armada_client.typings import JobState from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.client import ArmadaClient - -from armada.operators.grpc import GrpcChannelArgsDict, GrpcChannelArguments -from armada.operators.jobservice import ( - JobServiceClient, - default_jobservice_channel_options, -) -from armada.operators.utils import ( - airflow_error, - search_for_job_complete, - annotate_job_request_items, -) -from armada.jobservice import jobservice_pb2 - from google.protobuf.json_format import MessageToDict, ParseDict -import jinja2 - +from armada_client.client import ArmadaClient +from armada.auth import TokenRetriever +from armada.logs.pod_log_manager import PodLogManager +from armada.model import GrpcChannelArgs +from armada.triggers.armada import ArmadaTrigger -armada_logger = logging.getLogger("airflow.task") +class ArmadaOperator(BaseOperator, LoggingMixin): + """ + An Airflow operator that manages Job submission to Armada. -class ArmadaOperator(BaseOperator): + This operator submits a job to an Armada cluster, polls for its completion, + and handles job cancellation if the Airflow task is killed. """ - Implementation of an ArmadaOperator for airflow. - - Airflow operators inherit from BaseOperator. - - :param name: The name of the airflow task - :param armada_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. - :param job_service_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. - :param armada_queue: The queue name for Armada. - :param job_request_items: A PodSpec that is used by Armada for submitting a job - :param lookout_url_template: A URL template to be used to provide users - a valid link to the related lookout job in this operator's log. - The format should be: - "https://lookout.armada.domain/jobs?job_id=" where will - be replaced with the actual job ID. - :param poll_interval: How often to poll jobservice to get status. - :return: an armada operator instance + + template_fields: Sequence[str] = ("job_request", "job_set_prefix") + """ +Initializes a new ArmadaOperator. - template_fields: Sequence[str] = ("job_request_items",) +:param name: The name of the job to be submitted. +:type name: str +:param channel_args: The gRPC channel arguments for connecting to the Armada server. +:type channel_args: GrpcChannelArgs +:param armada_queue: The name of the Armada queue to which the job will be submitted. +:type armada_queue: str +:param job_request: The job to be submitted to Armada. +:type job_request: JobSubmitRequestItem +:param job_set_prefix: A string to prepend to the jobSet name +:type job_set_prefix: Optional[str] +:param lookout_url_template: Template for creating lookout links. If not specified +then no tracking information will be logged. 
+:type lookout_url_template: Optional[str] +:param poll_interval: The interval in seconds between polling for job status updates. +:type poll_interval: int +:param container_logs: Name of container whose logs will be published to stdout. +:type container_logs: Optional[str] +:param k8s_token_retriever: A serialisable Kubernetes token retriever object. We use +this to read logs from Kubernetes pods. +:type k8s_token_retriever: Optional[TokenRetriever] +:param deferrable: Whether the operator should run in a deferrable mode, allowing +for asynchronous execution. +:type deferrable: bool +:param job_acknowledgement_timeout: The timeout in seconds to wait for a job to be +acknowledged by Armada. +:type job_acknowledgement_timeout: int +:param kwargs: Additional keyword arguments to pass to the BaseOperator. +""" def __init__( self, name: str, - armada_channel_args: GrpcChannelArgsDict, - job_service_channel_args: GrpcChannelArgsDict, + channel_args: GrpcChannelArgs, armada_queue: str, - job_request_items: List[JobSubmitRequestItem], + job_request: JobSubmitRequestItem, + job_set_prefix: Optional[str] = "", lookout_url_template: Optional[str] = None, poll_interval: int = 30, + container_logs: Optional[str] = None, + k8s_token_retriever: Optional[TokenRetriever] = None, + deferrable: bool = conf.getboolean( + "operators", "default_deferrable", fallback=False + ), + job_acknowledgement_timeout: int = 5 * 60, **kwargs, ) -> None: super().__init__(**kwargs) self.name = name - self.armada_channel_args = GrpcChannelArguments(**armada_channel_args) - - if "options" not in job_service_channel_args: - job_service_channel_args["options"] = default_jobservice_channel_options - - self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args) + self.channel_args = channel_args self.armada_queue = armada_queue - self.job_request_items = job_request_items + self.job_request = job_request + self.job_set_prefix = job_set_prefix self.lookout_url_template = lookout_url_template self.poll_interval = poll_interval + self.container_logs = container_logs + self.k8s_token_retriever = k8s_token_retriever + self.deferrable = deferrable + self.job_acknowledgement_timeout = job_acknowledgement_timeout + self.job_id = None + self.job_set_id = None + + if self.container_logs and self.k8s_token_retriever is None: + self.log.warning( + "Token refresh mechanism not configured, airflow may stop retrieving " + "logs from Kubernetes" + ) def execute(self, context) -> None: """ - Executes the Armada Operator. - - Runs an Armada job and calls the job_service_client for polling. + Submits the job to Armada and polls for completion. - :param context: The airflow context. - - :return: None + :param context: The execution context provided by Airflow. + :type context: Context """ - job_service_client = JobServiceClient(self.job_service_channel_args.channel()) - # Health Check - health = job_service_client.health() - if health.status != jobservice_pb2.HealthCheckResponse.SERVING: - armada_logger.warn("Armada Job Service is not health") - # This allows us to use a unique id from airflow - # and have all jobs in a dag correspond to same jobset - job_set_id = context["run_id"] - - armada_client = ArmadaClient(channel=self.armada_channel_args.channel()) - job = armada_client.submit_jobs( - queue=self.armada_queue, - job_set_id=job_set_id, - job_request_items=annotate_job_request_items( - context, self.job_request_items - ), - ) + # We take the job_set_id from Airflow's run_id. 
This means that all jobs in the + # dag will be in the same jobset. + self.job_set_id = f"{self.job_set_prefix}{context['run_id']}" + self._annotate_job_request(context, self.job_request) - try: - job_id = job.job_response_items[0].job_id - except Exception: - raise AirflowException("Armada has issues submitting job") - - armada_logger.info("Running Armada job %s with id %s", self.name, job_id) - - lookout_url = self._get_lookout_url(job_id) - if len(lookout_url) > 0: - armada_logger.info("Lookout URL: %s", lookout_url) - - job_state, job_message = search_for_job_complete( - job_service_client=job_service_client, - armada_queue=self.armada_queue, - job_set_id=job_set_id, - airflow_task_name=self.name, - job_id=job_id, - poll_interval=self.poll_interval, + # Submit job or reattach to previously submitted job. We always do this + # synchronously. + self.job_id = self._reattach_or_submit_job( + context, self.armada_queue, self.job_set_id, self.job_request ) - armada_logger.info( - "Armada Job finished with %s and message: %s", job_state, job_message - ) - airflow_error(job_state, self.name, job_id) - def _get_lookout_url(self, job_id: str) -> str: - if self.lookout_url_template is None: - return "" - return self.lookout_url_template.replace("", job_id) + # Wait until finished + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=ArmadaTrigger( + job_id=self.job_id, + armada_queue=self.armada_queue, + job_set_id=self.job_set_id, + channel_args=self.channel_args, + poll_interval=self.poll_interval, + tracking_message=self._trigger_tracking_message(), + job_acknowledgement_timeout=self.job_acknowledgement_timeout, + container_logs=self.container_logs, + k8s_token_retriever=self.k8s_token_retriever, + job_request_namespace=self.job_request.namespace, + ), + method_name="_execute_complete", + ) + else: + self._poll_for_termination(self._trigger_tracking_message()) + + @cached_property + def client(self) -> ArmadaClient: + return ArmadaClient(channel=self.channel_args.channel()) + + @lru_cache(maxsize=None) + def pod_manager(self, k8s_context: str) -> PodLogManager: + return PodLogManager( + k8s_context=k8s_context, token_retriever=self.k8s_token_retriever + ) def render_template_fields( self, context: Context, jinja_env: Optional[jinja2.Environment] = None, ) -> None: - self.job_request_items = [ - MessageToDict(x, preserving_proto_field_name=True) - for x in self.job_request_items - ] + """ + Template all attributes listed in self.template_fields. + This mutates the attributes in-place and is irreversible. + + Args: + context (Context): The execution context provided by Airflow. + :param context: Airflow Context dict wi1th values to apply on content + :param jinja_env: jinja’s environment to use for rendering. 
+ """ + self.job_request = MessageToDict( + self.job_request, preserving_proto_field_name=True + ) super().render_template_fields(context, jinja_env) - self.job_request_items = [ - ParseDict(x, JobSubmitRequestItem()) for x in self.job_request_items - ] + self.job_request = ParseDict(self.job_request, JobSubmitRequestItem()) + + def _cancel_job(self) -> None: + try: + result = self.client.cancel_jobs( + queue=self.armada_queue, + job_set_id=self.job_set_id, + job_id=self.job_id, + ) + if len(list(result.cancelled_ids)) > 0: + self.log.info(f"Cancelled job with id {result.cancelled_ids}") + else: + self.log.warning(f"Failed to cancel job with id {self.job_id}") + except Exception as e: + self.log.warning(f"Failed to cancel job with id {self.job_id}: {e}") + + def on_kill(self) -> None: + if self.job_id is not None: + self.log.info( + f"on_kill called, cancelling job with id {self.job_id} in queue " + f"{self.armada_queue}" + ) + self._cancel_job() + + def _trigger_tracking_message(self): + if self.lookout_url_template: + return ( + f"Job details available at " + f'{self.lookout_url_template.replace("", self.job_id)}' + ) + + return "" + + def _execute_complete(self, _: Context, event: Dict[str, Any]): + if event["status"] == "error": + raise AirflowException(event["response"]) + + def _reattach_or_submit_job( + self, + context: Context, + queue: str, + job_set_id: str, + job_request: JobSubmitRequestItem, + ) -> str: + ti = context["ti"] + existing_id = ti.xcom_pull( + dag_id=ti.dag_id, task_ids=ti.task_id, key=f"{ti.try_number}" + ) + if existing_id is not None: + self.log.info( + f"Attached to existing job with id {existing_id['armada_job_id']}" + ) + return existing_id["armada_job_id"] + + job_id = self._submit_job(queue, job_set_id, job_request) + self.log.info(f"Submitted job with id {job_id}") + ti.xcom_push(key=f"{ti.try_number}", value={"armada_job_id": job_id}) + return job_id + + def _submit_job( + self, queue: str, job_set_id: str, job_request: JobSubmitRequestItem + ) -> str: + resp = self.client.submit_jobs(queue, job_set_id, [job_request]) + num_responses = len(resp.job_response_items) + + # We submitted exactly one job to armada, so we expect a single response + if num_responses != 1: + raise AirflowException( + f"No valid received from Armada (expected 1 job to be created " + f"but got {num_responses}" + ) + job = resp.job_response_items[0] + + # Throw if armada told us we had submitted something bad + if job.error: + raise AirflowException(f"Error submitting job to Armada: {job.error}") + + return job.job_id + + def _poll_for_termination(self, tracking_message: str) -> None: + last_log_time = None + run_details = None + state = JobState.UNKNOWN + + start_time = time.time() + job_acknowledged = False + while state.is_active(): + response = self.client.get_job_status([self.job_id]) + state = JobState(response.job_states[self.job_id]) + self.log.info( + f"job {self.job_id} is in state: {state.name}. 
{tracking_message}" + ) + + if state != JobState.UNKNOWN: + job_acknowledged = True + + if ( + not job_acknowledged + and int(time.time() - start_time) > self.job_acknowledgement_timeout + ): + self.log.info( + f"Job {self.job_id} not acknowledged by the Armada server within " + f"timeout ({self.job_acknowledgement_timeout}), terminating" + ) + self.on_kill() + return + + if self.container_logs and not run_details: + if state == JobState.RUNNING or state.is_terminal(): + run_details = self._get_latest_job_run_details(self.job_id) + + if run_details: + try: + # pod_name format is sufficient for now. Ideally pod name should be + # retrieved from queryapi + log_status = self.pod_manager( + run_details.cluster + ).fetch_container_logs( + pod_name=f"armada-{self.job_id}-0", + namespace=self.job_request.namespace, + container_name=self.container_logs, + since_time=last_log_time, + ) + last_log_time = log_status.last_log_time + except Exception as e: + self.log.warning(f"Error fetching logs {e}") + + time.sleep(self.poll_interval) + + self.log.info(f"job {self.job_id} terminated with state: {state.name}") + if state != JobState.SUCCEEDED: + raise AirflowException( + f"job {self.job_id} did not succeed. Final status was {state.name}" + ) + + def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: + job_details = self.client.get_job_details([job_id]).job_details[job_id] + if job_details and job_details.latest_run_id: + for run in job_details.job_runs: + if run.run_id == job_details.latest_run_id: + return run + return None + + @staticmethod + def _annotate_job_request(context, request: JobSubmitRequestItem): + if "ANNOTATION_KEY_PREFIX" in os.environ: + annotation_key_prefix = f'{os.environ.get("ANNOTATION_KEY_PREFIX")}' + else: + annotation_key_prefix = "armadaproject.io/" + + task_id = context["ti"].task_id + run_id = context["run_id"] + dag_id = context["dag"].dag_id + + request.annotations[annotation_key_prefix + "taskId"] = task_id + request.annotations[annotation_key_prefix + "taskRunId"] = run_id + request.annotations[annotation_key_prefix + "dagId"] = dag_id diff --git a/third_party/airflow/armada/operators/armada_deferrable.py b/third_party/airflow/armada/operators/armada_deferrable.py deleted file mode 100644 index f7aa1413637..00000000000 --- a/third_party/airflow/armada/operators/armada_deferrable.py +++ /dev/null @@ -1,301 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
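For readers skimming the operator diff above, the synchronous path boils down to a poll loop against the Armada query API. A simplified sketch of that approach, assuming an `ArmadaClient` instance and a known job id (the helper name `wait_for_job` is hypothetical; log streaming and the acknowledgement timeout are omitted for brevity):

```python
import time

from armada_client.typings import JobState


def wait_for_job(client, job_id: str, poll_interval: int = 30) -> JobState:
    # UNKNOWN counts as active here, mirroring _poll_for_termination: keep
    # polling until Armada reports a terminal state for the job.
    state = JobState.UNKNOWN
    while state.is_active():
        response = client.get_job_status([job_id])
        state = JobState(response.job_states[job_id])
        if state.is_active():
            time.sleep(poll_interval)
    return state
```

The real operator additionally gives up if the job is not acknowledged within `job_acknowledgement_timeout` seconds and, when `container_logs` is set, streams pod logs between polls.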
- -import logging -from typing import Optional, Sequence, List - -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.utils.context import Context - -from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.client import ArmadaClient - -from armada.operators.jobservice import ( - JobServiceClient, - default_jobservice_channel_options, -) -from armada.operators.grpc import GrpcChannelArgsDict, GrpcChannelArguments -from armada.operators.jobservice_asyncio import JobServiceAsyncIOClient -from armada.operators.utils import ( - airflow_error, - search_for_job_complete_async, - annotate_job_request_items, -) -from armada.jobservice import jobservice_pb2 - -from google.protobuf.json_format import MessageToDict, ParseDict - -import jinja2 - - -armada_logger = logging.getLogger("airflow.task") - - -class ArmadaDeferrableOperator(BaseOperator): - """ - Implementation of a deferrable armada operator for airflow. - - Distinguished from ArmadaOperator by its ability to defer itself after - submitting its job_request_items. - - See - https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html - for more information about deferrable airflow operators. - - Airflow operators inherit from BaseOperator. - - :param name: The name of the airflow task. - :param armada_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. - :param job_service_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. - :param armada_queue: The queue name for Armada. - :param job_request_items: A PodSpec that is used by Armada for submitting a job. - :param lookout_url_template: A URL template to be used to provide users - a valid link to the related lookout job in this operator's log. - The format should be: - "https://lookout.armada.domain/jobs?job_id=" where will - be replaced with the actual job ID. - :param poll_interval: How often to poll jobservice to get status. - :return: A deferrable armada operator instance. - """ - - template_fields: Sequence[str] = ("job_request_items",) - - def __init__( - self, - name: str, - armada_channel_args: GrpcChannelArgsDict, - job_service_channel_args: GrpcChannelArgsDict, - armada_queue: str, - job_request_items: List[JobSubmitRequestItem], - lookout_url_template: Optional[str] = None, - poll_interval: int = 30, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.name = name - self.armada_channel_args = GrpcChannelArguments(**armada_channel_args) - - if "options" not in job_service_channel_args: - job_service_channel_args["options"] = default_jobservice_channel_options - - self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args) - self.armada_queue = armada_queue - self.job_request_items = job_request_items - self.lookout_url_template = lookout_url_template - self.poll_interval = poll_interval - - def serialize(self) -> dict: - """ - Get a serialized version of this object. - - :return: A dict of keyword arguments used when instantiating - this object. 
- """ - - return { - "task_id": self.task_id, - "name": self.name, - "armada_channel_args": self.armada_channel_args.serialize(), - "job_service_channel_args": self.job_service_channel_args.serialize(), - "armada_queue": self.armada_queue, - "job_request_items": self.job_request_items, - "lookout_url_template": self.lookout_url_template, - "poll_interval": self.poll_interval, - } - - def execute(self, context) -> None: - """ - Executes the Armada Operator. Only meant to be called by airflow. - - Submits an Armada job and defers itself to ArmadaJobCompleteTrigger to wait - until the job completes. - - :param context: The airflow context. - - :return: None - """ - self.job_request_items = annotate_job_request_items( - context=context, job_request_items=self.job_request_items - ) - job_service_client = JobServiceClient(self.job_service_channel_args.channel()) - - # Health Check - health = job_service_client.health() - if health.status != jobservice_pb2.HealthCheckResponse.SERVING: - armada_logger.warn("Armada Job Service is not healthy.") - else: - armada_logger.debug("Jobservice is healthy.") - - armada_client = ArmadaClient(channel=self.armada_channel_args.channel()) - - armada_logger.debug("Submitting job(s).") - # This allows us to use a unique id from airflow - # and have all jobs in a dag correspond to same jobset - job = armada_client.submit_jobs( - queue=self.armada_queue, - job_set_id=context["run_id"], - job_request_items=self.job_request_items, - ) - - try: - job_id = job.job_response_items[0].job_id - except Exception: - raise AirflowException("Error submitting job(s) to Armada") - - armada_logger.info("Running Armada job %s with id %s", self.name, job_id) - - lookout_url = self._get_lookout_url(job_id) - if len(lookout_url) > 0: - armada_logger.info("Lookout URL: %s", lookout_url) - - # TODO: configurable timeout? - self.defer( - trigger=ArmadaJobCompleteTrigger( - job_id=job_id, - job_service_channel_args=self.job_service_channel_args.serialize(), - armada_queue=self.armada_queue, - job_set_id=context["run_id"], - airflow_task_name=self.name, - poll_interval=self.poll_interval, - ), - method_name="resume_job_complete", - kwargs={"job_id": job_id}, - ) - - def resume_job_complete(self, context, event: dict, job_id: str) -> None: - """ - Resumes this operator after deferring itself to ArmadaJobCompleteTrigger. - Only meant to be called from within Airflow. - - Reports the result of the job and returns. - - :param context: The airflow context. - :param event: The payload from the TriggerEvent raised by - ArmadaJobCompleteTrigger. - :param job_id: The job ID. - :return: None - """ - - job_state = event["job_state"] - job_message = event["job_message"] - - armada_logger.info( - "Armada Job finished with %s and message: %s", job_state, job_message - ) - airflow_error(job_state, self.name, job_id) - - def _get_lookout_url(self, job_id: str) -> str: - if self.lookout_url_template is None: - return "" - return self.lookout_url_template.replace("", job_id) - - def render_template_fields( - self, - context: Context, - jinja_env: Optional[jinja2.Environment] = None, - ) -> None: - self.job_request_items = [ - MessageToDict(x, preserving_proto_field_name=True) - for x in self.job_request_items - ] - super().render_template_fields(context, jinja_env) - self.job_request_items = [ - ParseDict(x, JobSubmitRequestItem()) for x in self.job_request_items - ] - - -class ArmadaJobCompleteTrigger(BaseTrigger): - """ - An airflow trigger that monitors the job state of an armada job. 
- - Triggers when the job is complete. - - :param job_id: The job ID to monitor. - :param job_service_channel_args: GRPC channel arguments to be used when - creating a grpc channel to connect to the job service instance. - :param armada_queue: The name of the armada queue. - :param job_set_id: The ID of the job set. - :param airflow_task_name: Name of the airflow task to which this trigger - belongs. - :param poll_interval: How often to poll jobservice to get status. - :return: An armada job complete trigger instance. - """ - - def __init__( - self, - job_id: str, - job_service_channel_args: GrpcChannelArgsDict, - armada_queue: str, - job_set_id: str, - airflow_task_name: str, - poll_interval: int = 30, - ) -> None: - super().__init__() - self.job_id = job_id - self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args) - self.armada_queue = armada_queue - self.job_set_id = job_set_id - self.airflow_task_name = airflow_task_name - self.poll_interval = poll_interval - - def serialize(self) -> tuple: - return ( - "armada.operators.armada_deferrable.ArmadaJobCompleteTrigger", - { - "job_id": self.job_id, - "job_service_channel_args": self.job_service_channel_args.serialize(), - "armada_queue": self.armada_queue, - "job_set_id": self.job_set_id, - "airflow_task_name": self.airflow_task_name, - "poll_interval": self.poll_interval, - }, - ) - - def __eq__(self, o): - return ( - self.task_id == o.task_id - and self.job_id == o.job_id - and self.job_service_channel_args == o.job_service_channel_args - and self.armada_queue == o.armada_queue - and self.job_set_id == o.job_set_id - and self.airflow_task_name == o.airflow_task_name - and self.poll_interval == o.poll_interval - ) - - async def run(self): - """ - Runs the trigger. Meant to be called by an airflow triggerer process. - """ - job_service_client = JobServiceAsyncIOClient( - channel=self.job_service_channel_args.aio_channel() - ) - - job_state, job_message = await search_for_job_complete_async( - armada_queue=self.armada_queue, - job_set_id=self.job_set_id, - airflow_task_name=self.airflow_task_name, - job_id=self.job_id, - job_service_client=job_service_client, - log=self.log, - poll_interval=self.poll_interval, - ) - yield TriggerEvent({"job_state": job_state, "job_message": job_message}) diff --git a/third_party/airflow/armada/operators/grpc.py b/third_party/airflow/armada/operators/grpc.py deleted file mode 100644 index 3e146ccce07..00000000000 --- a/third_party/airflow/armada/operators/grpc.py +++ /dev/null @@ -1,149 +0,0 @@ -import importlib -from typing import Optional, Sequence, Tuple, Any, TypedDict - -import grpc - - -class CredentialsCallbackDict(TypedDict): - """ - Helper class to provide stronger type checking on Credential callback args. - """ - - module_name: str - function_name: str - function_kwargs: dict - - -class GrpcChannelArgsDict(TypedDict): - """ - Helper class to provide stronger type checking on Grpc channel arugments. - """ - - target: str - options: Optional[Sequence[Tuple[str, Any]]] - compression: Optional[grpc.Compression] - credentials_callback_args: Optional[CredentialsCallbackDict] - - -class CredentialsCallback(object): - """ - Allows the use of an arbitrary callback function to get grpc credentials. - - :param module_name: The fully qualified python module name where the - function is located. - :param function_name: The name of the function to be called. - :param function_kwargs: Keyword arguments to function_name in a dictionary. 
- """ - - def __init__( - self, - module_name: str, - function_name: str, - function_kwargs: dict, - ) -> None: - self.module_name = module_name - self.function_name = function_name - self.function_kwargs = function_kwargs - - def call(self): - """Do the callback to get grpc credentials.""" - module = importlib.import_module(self.module_name) - func = getattr(module, self.function_name) - return func(**self.function_kwargs) - - -class GrpcChannelArguments(object): - """ - A Serializable GRPC Arguments Object. - - :param target: Target keyword argument used - when instantiating a grpc channel. - :param credentials_callback_args: Arguments to CredentialsCallback to use - when instantiating a grpc channel that takes credentials. - :param options: options keyword argument used - when instantiating a grpc channel. - :param compression: compression keyword argument used - when instantiating a grpc channel. - :return: a GrpcChannelArguments instance - """ - - def __init__( - self, - target: str, - options: Optional[Sequence[Tuple[str, Any]]] = None, - compression: Optional[grpc.Compression] = None, - credentials_callback_args: CredentialsCallbackDict = None, - ) -> None: - self.target = target - self.options = options - self.compression = compression - self.credentials_callback = None - self.credentials_callback_args = credentials_callback_args - - if credentials_callback_args is not None: - self.credentials_callback = CredentialsCallback(**credentials_callback_args) - - def __eq__(self, o): - return ( - self.target == o.target - and self.options == o.options - and self.compression == o.compression - and self.credentials_callback_args == o.credentials_callback_args - ) - - def channel(self) -> grpc.Channel: - """ - Create a grpc.Channel based on arguments supplied to this object. - - :return: Return grpc.insecure_channel if credentials is None. Otherwise - returns grpc.secure_channel. - """ - - if self.credentials_callback is None: - return grpc.insecure_channel( - target=self.target, - options=self.options, - compression=self.compression, - ) - return grpc.secure_channel( - target=self.target, - credentials=self.credentials_callback.call(), - options=self.options, - compression=self.compression, - ) - - def aio_channel(self) -> grpc.aio.Channel: - """ - Create a grpc.aio.Channel (asyncio) based on arguments supplied to this object. - - :return: Return grpc.aio.insecure_channel if credentials is None. Otherwise - returns grpc.aio.secure_channel. - """ - - if self.credentials_callback is None: - return grpc.aio.insecure_channel( - target=self.target, - options=self.options, - compression=self.compression, - ) - return grpc.aio.secure_channel( - target=self.target, - credentials=self.credentials_callback.call(), - options=self.options, - compression=self.compression, - ) - - def serialize(self) -> dict: - """ - Get a serialized version of this object. - - :return: A dict of keyword arguments used when calling - a grpc channel or instantiating this object. 
- """ - - return { - "target": self.target, - "credentials_callback_args": self.credentials_callback_args, - "options": self.options, - "compression": self.compression, - } diff --git a/third_party/airflow/armada/operators/jobservice.py b/third_party/airflow/armada/operators/jobservice.py deleted file mode 100644 index c6445286064..00000000000 --- a/third_party/airflow/armada/operators/jobservice.py +++ /dev/null @@ -1,97 +0,0 @@ -import json -from typing import Optional - -from armada.jobservice import jobservice_pb2_grpc, jobservice_pb2 - -import grpc -from google.protobuf import empty_pb2 - -default_jobservice_channel_options = [ - ( - "grpc.service_config", - json.dumps( - { - "methodConfig": [ - { - "name": [{"service": "jobservice.JobService"}], - "retryPolicy": { - "maxAttempts": 6 * 5, # A little under 5 minutes. - "initialBackoff": "0.1s", - "maxBackoff": "10s", - "backoffMultiplier": 2, - "retryableStatusCodes": ["UNAVAILABLE"], - }, - } - ] - } - ), - ) -] - - -class JobServiceClient: - """ - The JobService Client - - Implementation of gRPC stubs from JobService - - :param channel: gRPC channel used for authentication. See - https://grpc.github.io/grpc/python/grpc.html - for more information. - :return: a job service client instance - """ - - def __init__(self, channel): - self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel) - - def get_job_status( - self, queue: str, job_set_id: str, job_id: str - ) -> jobservice_pb2.JobServiceResponse: - """Get job status of a given job in a queue and job_set_id. - - Uses the GetJobStatus rpc to get a status of your job - - :param queue: The name of the queue - :param job_set_id: The name of the job set (a grouping of jobs) - :param job_id: The id of the job - :return: A Job Service Request (State, Error) - """ - job_service_request = jobservice_pb2.JobServiceRequest( - queue=queue, job_set_id=job_set_id, job_id=job_id - ) - return self.job_stub.GetJobStatus(job_service_request) - - def health(self) -> jobservice_pb2.HealthCheckResponse: - """Health Check for GRPC Request""" - return self.job_stub.Health(request=empty_pb2.Empty()) - - -def get_retryable_job_service_client( - target: str, - credentials: Optional[grpc.ChannelCredentials] = None, - compression: Optional[grpc.Compression] = None, -) -> JobServiceClient: - """ - Get a JobServiceClient that has retry configured - - :param target: grpc channel target - :param credentials: grpc channel credentials (if needed) - :param compresion: grpc channel compression - - :return: A job service client instance - """ - channel = None - if credentials is None: - channel = grpc.insecure_channel( - target=target, - options=default_jobservice_channel_options, - compression=compression, - ) - else: - channel = grpc.secure_channel( - target=target, - credentials=credentials, - options=default_jobservice_channel_options, - compression=compression, - ) - return JobServiceClient(channel) diff --git a/third_party/airflow/armada/operators/jobservice_asyncio.py b/third_party/airflow/armada/operators/jobservice_asyncio.py deleted file mode 100644 index a40b9fc14a0..00000000000 --- a/third_party/airflow/armada/operators/jobservice_asyncio.py +++ /dev/null @@ -1,80 +0,0 @@ -from armada.jobservice import ( - jobservice_pb2_grpc, - jobservice_pb2, -) -from armada.operators.jobservice import default_jobservice_channel_options - -import grpc -from typing import Optional - -from google.protobuf import empty_pb2 - - -class JobServiceAsyncIOClient: - """ - The JobService AsyncIO Client - - AsyncIO implementation 
of gRPC stubs from JobService - - :param channel: AsyncIO gRPC channel used for authentication. See - https://grpc.github.io/grpc/python/grpc_asyncio.html - for more information. - :return: A job service client instance - """ - - def __init__(self, channel: grpc.aio.Channel) -> None: - self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel) - - async def get_job_status( - self, queue: str, job_set_id: str, job_id: str - ) -> jobservice_pb2.JobServiceResponse: - """Get job status of a given job in a queue and job_set_id. - - Uses the GetJobStatus rpc to get a status of your job - - :param queue: The name of the queue - :param job_set_id: The name of the job set (a grouping of jobs) - :param job_id: The id of the job - :return: A Job Service Request (State, Error) - """ - job_service_request = jobservice_pb2.JobServiceRequest( - queue=queue, job_set_id=job_set_id, job_id=job_id - ) - response = await self.job_stub.GetJobStatus(job_service_request) - return response - - async def health(self) -> jobservice_pb2.HealthCheckResponse: - """Health Check for GRPC Request""" - response = await self.job_stub.Health(request=empty_pb2.Empty()) - return response - - -def get_retryable_job_service_asyncio_client( - target: str, - credentials: Optional[grpc.ChannelCredentials], - compression: Optional[grpc.Compression], -) -> JobServiceAsyncIOClient: - """ - Get a JobServiceAsyncIOClient that has retry configured - - :param target: grpc channel target - :param credentials: grpc channel credentials (if needed) - :param compresion: grpc channel compression - - :return: A job service asyncio client instance - """ - channel = None - if credentials is None: - channel = grpc.aio.insecure_channel( - target=target, - options=default_jobservice_channel_options, - compression=compression, - ) - else: - channel = grpc.aio.secure_channel( - target=target, - credentials=credentials, - options=default_jobservice_channel_options, - compression=compression, - ) - return JobServiceAsyncIOClient(channel) diff --git a/third_party/airflow/armada/operators/utils.py b/third_party/airflow/armada/operators/utils.py deleted file mode 100644 index 1ab7fa35d04..00000000000 --- a/third_party/airflow/armada/operators/utils.py +++ /dev/null @@ -1,289 +0,0 @@ -import asyncio -import logging -import os -import time - -from airflow.exceptions import AirflowException -from typing import List, Optional, Tuple -from enum import Enum - -from armada.operators.jobservice import JobServiceClient -from armada.operators.jobservice_asyncio import JobServiceAsyncIOClient -from armada.jobservice import jobservice_pb2 -from armada_client.armada import submit_pb2 - - -class JobState(Enum): - SUBMITTED = 0 - DUPLICATE_FOUND = 1 - RUNNING = 2 - FAILED = 3 - SUCCEEDED = 4 - CANCELLED = 5 - JOB_ID_NOT_FOUND = 6 - CONNECTION_ERR = 7 - - -_pb_to_job_state = { - jobservice_pb2.JobServiceResponse.SUBMITTED: JobState.SUBMITTED, - jobservice_pb2.JobServiceResponse.DUPLICATE_FOUND: JobState.DUPLICATE_FOUND, - jobservice_pb2.JobServiceResponse.RUNNING: JobState.RUNNING, - jobservice_pb2.JobServiceResponse.FAILED: JobState.FAILED, - jobservice_pb2.JobServiceResponse.SUCCEEDED: JobState.SUCCEEDED, - jobservice_pb2.JobServiceResponse.CANCELLED: JobState.CANCELLED, - jobservice_pb2.JobServiceResponse.JOB_ID_NOT_FOUND: JobState.JOB_ID_NOT_FOUND, - # NOTE(Clif): For whatever reason CONNECTION_ERR is not present in the - # generated protobuf. 
- 7: JobState.CONNECTION_ERR, -} - - -def job_state_from_pb(state) -> JobState: - return _pb_to_job_state[state] - - -def airflow_error(job_state: JobState, name: str, job_id: str): - """Throw an error on a terminal event if job errored out - - :param job_state: A JobState enum class - :param name: The name of your armada job - :param job_id: The job id that armada assigns to it - :return: No Return or an AirflowFailException. - - AirflowFailException tells Airflow Schedule to not reschedule the task - - """ - if job_state == JobState.SUCCEEDED: - return - if ( - job_state == JobState.FAILED - or job_state == JobState.CANCELLED - or job_state == JobState.JOB_ID_NOT_FOUND - ): - job_message = job_state.name - # AirflowException allows operator-level retries. AirflowFailException - # does *not*. - raise AirflowException(f"The Armada job {name}:{job_id} {job_message}") - - -def default_job_status_callable( - armada_queue: str, - job_set_id: str, - job_id: str, - job_service_client: JobServiceClient, -) -> jobservice_pb2.JobServiceResponse: - return job_service_client.get_job_status( - queue=armada_queue, job_id=job_id, job_set_id=job_set_id - ) - - -armada_logger = logging.getLogger("airflow.task") - - -def search_for_job_complete( - armada_queue: str, - job_set_id: str, - airflow_task_name: str, - job_id: str, - poll_interval: int = 30, - job_service_client: Optional[JobServiceClient] = None, - job_status_callable=default_job_status_callable, - time_out_for_failure: int = 7200, -) -> Tuple[JobState, str]: - """ - - Poll JobService cache until you get a terminated event. - - A terminated event is SUCCEEDED, FAILED or CANCELLED - - :param armada_queue: The queue for armada - :param job_set_id: Your job_set_id - :param airflow_task_name: The name of your armada job - :param poll_interval: Polling interval for jobservice to get status. - :param job_id: The name of the job id that armada assigns to it - :param job_service_client: A JobServiceClient that is used for polling. - It is optional only for testing - :param job_status_callable: A callable object for test injection. - :param time_out_for_failure: The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - :return: A tuple of JobStateEnum, message - """ - start_time = time.time() - # Overwrite time_out_for_failure by environment variable for configuration - armada_time_out_env = os.getenv("ARMADA_AIRFLOW_TIME_OUT_JOB_ID") - if armada_time_out_env: - time_out_for_failure = int(armada_time_out_env) - while True: - # The else statement is for testing purposes. 
- # We want to allow a test callable to be passed - if job_service_client: - job_status_return = job_status_callable( - armada_queue=armada_queue, - job_id=job_id, - job_set_id=job_set_id, - job_service_client=job_service_client, - ) - else: - job_status_return = job_status_callable( - armada_queue=armada_queue, job_id=job_id, job_set_id=job_set_id - ) - - job_state = job_state_from_pb(job_status_return.state) - armada_logger.debug(f"Got job state '{job_state.name}' for job {job_id}") - - time.sleep(poll_interval) - if job_state == JobState.SUCCEEDED: - job_message = f"Armada {airflow_task_name}:{job_id} succeeded" - break - if job_state == JobState.FAILED: - job_message = ( - f"Armada {airflow_task_name}:{job_id} failed\n" - f"failed with reason {job_status_return.error}" - ) - break - if job_state == JobState.CANCELLED: - job_message = f"Armada {airflow_task_name}:{job_id} cancelled" - break - if job_state == JobState.CONNECTION_ERR: - log_messages = ( - f"Armada {airflow_task_name}:{job_id} connection error (will retry)" - f"failed with reason {job_status_return.error}" - ) - armada_logger.warning(log_messages) - continue - - if job_state == JobState.JOB_ID_NOT_FOUND: - end_time = time.time() - time_elasped = int(end_time) - int(start_time) - if time_elasped > time_out_for_failure: - job_state = JobState.JOB_ID_NOT_FOUND - job_message = ( - f"Armada {airflow_task_name}:{job_id} could not find a job id and\n" - f"hit a timeout" - ) - break - - return job_state, job_message - - -def annotate_job_request_items( - context, job_request_items: List[submit_pb2.JobSubmitRequestItem] -) -> List[submit_pb2.JobSubmitRequestItem]: - """ - Annotates the inbound job request items with Airflow context elements - - :param context: The airflow context. - - :param job_request_items: The job request items to be sent to armada - - :return: annotated job request items for armada - """ - task_instance = context["ti"] - task_id = task_instance.task_id - run_id = context["run_id"] - dag_id = context["dag"].dag_id - - for item in job_request_items: - item.annotations[get_annotation_key_prefix() + "taskId"] = task_id - item.annotations[get_annotation_key_prefix() + "taskRunId"] = run_id - item.annotations[get_annotation_key_prefix() + "dagId"] = dag_id - - return job_request_items - - -ANNOTATION_KEY_PREFIX = "armadaproject.io/" - - -def get_annotation_key_prefix() -> str: - """ - Provides the annotation key prefix, - which can be specified in env var ANNOTATION_KEY_PREFIX. - A default is provided if the env var is not defined - - :return: string annotation key prefix - """ - env_var_name = "ANNOTATION_KEY_PREFIX" - if env_var_name in os.environ: - return f"{os.environ.get(env_var_name)}" - else: - return ANNOTATION_KEY_PREFIX - - -async def search_for_job_complete_async( - armada_queue: str, - job_set_id: str, - airflow_task_name: str, - job_id: str, - job_service_client: JobServiceAsyncIOClient, - log, - poll_interval: int, - time_out_for_failure: int = 7200, -) -> Tuple[JobState, str]: - """ - - Poll JobService cache asyncronously until you get a terminated event. - - A terminated event is SUCCEEDED, FAILED or CANCELLED - - :param armada_queue: The queue for armada - :param job_set_id: Your job_set_id - :param airflow_task_name: The name of your armada job - :param job_id: The name of the job id that armada assigns to it - :param job_service_client: A JobServiceClient that is used for polling. - It is optional only for testing - :param poll_interval: How often to poll jobservice to get status. 
- :param time_out_for_failure: The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - :return: A tuple of JobStateEnum, message - """ - start_time = time.time() - # Overwrite time_out_for_failure by environment variable for configuration - armada_time_out_env = os.getenv("ARMADA_AIRFLOW_TIME_OUT_JOB_ID") - if armada_time_out_env: - time_out_for_failure = int(armada_time_out_env) - while True: - job_status_return = await job_service_client.get_job_status( - queue=armada_queue, - job_id=job_id, - job_set_id=job_set_id, - ) - - job_state = job_state_from_pb(job_status_return.state) - log.debug(f"Got job state '{job_state.name}' for job {job_id}") - - await asyncio.sleep(poll_interval) - - if job_state == JobState.SUCCEEDED: - job_message = f"Armada {airflow_task_name}:{job_id} succeeded" - break - if job_state == JobState.FAILED: - job_message = ( - f"Armada {airflow_task_name}:{job_id} failed\n" - f"failed with reason {job_status_return.error}" - ) - break - if job_state == JobState.CANCELLED: - job_message = f"Armada {airflow_task_name}:{job_id} cancelled" - break - if job_state == JobState.CONNECTION_ERR: - log_messages = ( - f"Armada {airflow_task_name}:{job_id} connection error (will retry)" - f"failed with reason {job_status_return.error}" - ) - log.warning(log_messages) - continue - - if job_state == JobState.JOB_ID_NOT_FOUND: - end_time = time.time() - time_elasped = int(end_time) - int(start_time) - if time_elasped > time_out_for_failure: - job_state = JobState.JOB_ID_NOT_FOUND - job_message = ( - f"Armada {airflow_task_name}:{job_id} could not find a job id and\n" - f"hit a timeout" - ) - break - - return job_state, job_message diff --git a/third_party/airflow/armada/provider.yaml b/third_party/airflow/armada/provider.yaml deleted file mode 100644 index ff95bce9210..00000000000 --- a/third_party/airflow/armada/provider.yaml +++ /dev/null @@ -1,38 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- ---- -package-name: apache-airflow-providers-armada -name: Armada -description: | - `ArmadaOperator `__ -versions: - -0.3.14 - -additional-dependencies: - - apache-airflow>=2.2.0 - -integrations: - - integration-name: Armada - external-doc-url: https://armadaproject.io/ - logo: /integration-logos/armada/armada.png - tags: [software] - -operators: - - integration-name: Armada - python-modules: - - airflow.providers.armada.operators.armada \ No newline at end of file diff --git a/third_party/airflow/armada/triggers/armada.py b/third_party/airflow/armada/triggers/armada.py new file mode 100644 index 00000000000..284fe305169 --- /dev/null +++ b/third_party/airflow/armada/triggers/armada.py @@ -0,0 +1,269 @@ +import asyncio +import importlib +import time +from functools import cached_property +from typing import AsyncIterator, Any, Optional, Tuple, Dict + +from airflow.triggers.base import BaseTrigger, TriggerEvent +from armada_client.armada.job_pb2 import JobRunDetails +from armada_client.typings import JobState + +from armada_client.asyncio_client import ArmadaAsyncIOClient +from armada.auth import TokenRetriever +from armada.logs.pod_log_manager import PodLogManagerAsync +from armada.model import GrpcChannelArgs +from pendulum import DateTime + + +class ArmadaTrigger(BaseTrigger): + """ + An Airflow Trigger that can asynchronously manage an Armada job. + """ + + def __init__( + self, + job_id: str, + armada_queue: str, + job_set_id: str, + poll_interval: int, + tracking_message: str, + job_acknowledgement_timeout: int, + job_request_namespace: str, + channel_args: GrpcChannelArgs = None, + channel_args_details: Dict[str, Any] = None, + container_logs: Optional[str] = None, + k8s_token_retriever: Optional[TokenRetriever] = None, + k8s_token_retriever_details: Optional[Tuple[str, Dict[str, Any]]] = None, + last_log_time: Optional[DateTime] = None, + ): + """ + Initializes an instance of ArmadaTrigger, which is an Airflow trigger for + managing Armada jobs asynchronously. + + :param job_id: The unique identifier of the job to be monitored. + :type job_id: str + :param armada_queue: The Armada queue under which the job was submitted. + Required for job cancellation. + :type armada_queue: str + :param job_set_id: The unique identifier of the job set under which the job + was submitted. Required for job cancellation. + :type job_set_id: str + :param poll_interval: The interval, in seconds, at which the job status will be + checked. + :type poll_interval: int + :param tracking_message: A message to log or display for tracking the job + status. + :type tracking_message: str + :param job_acknowledgement_timeout: The timeout, in seconds, to wait for the job + to be acknowledged by Armada. + :type job_acknowledgement_timeout: int + :param job_request_namespace: The Kubernetes namespace under which the job was + submitted. + :type job_request_namespace: str + :param channel_args: The arguments to configure the gRPC channel. If None, + default arguments will be used. + :type channel_args: GrpcChannelArgs, optional + :param channel_args_details: Additional details or configurations for the gRPC + channel as a dictionary. Only used when + the trigger is rehydrated after serialization. 
+ :type channel_args_details: dict[str, Any], optional + :param container_logs: Name of container from which to retrieve logs + :type container_logs: str, optional + :param k8s_token_retriever: An optional instance of type TokenRetriever, used to + refresh the Kubernetes auth token + :type k8s_token_retriever: TokenRetriever, optional + :param k8s_token_retriever_details: Configuration for TokenRetriever as a + dictionary. + Only used when the trigger is + rehydrated after serialization. + :type k8s_token_retriever_details: Tuple[str, Dict[str, Any]], optional + :param last_log_time: where to resume logs from + :type last_log_time: DateTime, optional + """ + super().__init__() + self.job_id = job_id + self.armada_queue = armada_queue + self.job_set_id = job_set_id + self.poll_interval = poll_interval + self.tracking_message = tracking_message + self.job_acknowledgement_timeout = job_acknowledgement_timeout + self.container_logs = container_logs + self.last_log_time = last_log_time + self.job_request_namespace = job_request_namespace + self._pod_manager = None + self.k8s_token_retriever = k8s_token_retriever + + if channel_args: + self.channel_args = channel_args + elif channel_args_details: + self.channel_args = GrpcChannelArgs(**channel_args_details) + else: + raise f"must provide either {channel_args} or {channel_args_details}" + + if k8s_token_retriever_details: + classpath, kwargs = k8s_token_retriever_details + module_path, class_name = classpath.rsplit( + ".", 1 + ) # Split the classpath to module and class name + module = importlib.import_module( + module_path + ) # Dynamically import the module + cls = getattr(module, class_name) # Get the class from the module + self.k8s_token_retriever = cls( + **kwargs + ) # Instantiate the class with the deserialized kwargs + + def serialize(self) -> tuple: + """ + Serialises the state of this Trigger. + When the Trigger is re-hydrated, these values will be passed to init() as kwargs + :return: + """ + k8s_token_retriever_details = ( + self.k8s_token_retriever.serialize() if self.k8s_token_retriever else None + ) + return ( + "armada.triggers.armada.ArmadaTrigger", + { + "job_id": self.job_id, + "armada_queue": self.armada_queue, + "job_set_id": self.job_set_id, + "channel_args_details": self.channel_args.serialize(), + "poll_interval": self.poll_interval, + "tracking_message": self.tracking_message, + "job_acknowledgement_timeout": self.job_acknowledgement_timeout, + "container_logs": self.container_logs, + "k8s_token_retriever_details": k8s_token_retriever_details, + "last_log_time": self.last_log_time, + "job_request_namespace": self.job_request_namespace, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """ + Run the Trigger Asynchronously. This will poll Armada until the Job reaches a + terminal state + """ + try: + response = await self._poll_for_termination(self.job_id) + yield TriggerEvent(response) + except Exception as exc: + yield TriggerEvent( + { + "status": "error", + "job_id": self.job_id, + "response": f"Job {self.job_id} did not succeed. 
Error was {exc}", + } + ) + + """Cannot call on_kill from trigger, will asynchronously cancel jobs instead.""" + + async def _cancel_job(self) -> None: + try: + result = await self.client.cancel_jobs( + queue=self.armada_queue, + job_set_id=self.job_set_id, + job_id=self.job_id, + ) + if len(list(result.cancelled_ids)) > 0: + self.log.info(f"Cancelled job with id {result.cancelled_ids}") + else: + self.log.warning(f"Failed to cancel job with id {self.job_id}") + except Exception as e: + self.log.warning(f"Failed to cancel job with id {self.job_id}: {e}") + + async def _poll_for_termination(self, job_id: str) -> Dict[str, Any]: + state = JobState.UNKNOWN + start_time = time.time() + job_acknowledged = False + run_details = None + + # Poll for terminal state + while state.is_active(): + resp = await self.client.get_job_status([job_id]) + state = JobState(resp.job_states[job_id]) + self.log.info( + f"Job {job_id} is in state: {state.name}. {self.tracking_message}" + ) + + if state != JobState.UNKNOWN: + job_acknowledged = True + + if ( + not job_acknowledged + and int(time.time() - start_time) > self.job_acknowledgement_timeout + ): + await self._cancel_job() + return { + "status": "error", + "job_id": job_id, + "response": f"Job {job_id} not acknowledged within timeout " + f"{self.job_acknowledgement_timeout}.", + } + + if self.container_logs and not run_details: + if state == JobState.RUNNING or state.is_terminal(): + run_details = await self._get_latest_job_run_details(self.job_id) + + if run_details: + try: + log_status = await self.pod_manager( + run_details.cluster + ).fetch_container_logs( + pod_name=f"armada-{self.job_id}-0", + namespace=self.job_request_namespace, + container_name=self.container_logs, + since_time=self.last_log_time, + ) + self.last_log_time = log_status.last_log_time + except Exception as e: + self.log.exception(e) + + if state.is_active(): + self.log.debug(f"Sleeping for {self.poll_interval} seconds") + await asyncio.sleep(self.poll_interval) + + self.log.info(f"Job {job_id} terminated with state:{state.name}") + if state != JobState.SUCCEEDED: + return { + "status": "error", + "job_id": job_id, + "response": f"Job {job_id} did not succeed. 
Final status was " + f"{state.name}", + } + return { + "status": "success", + "job_id": job_id, + "response": f"Job {job_id} succeeded", + } + + @cached_property + def client(self) -> ArmadaAsyncIOClient: + return ArmadaAsyncIOClient(channel=self.channel_args.aio_channel()) + + def pod_manager(self, k8s_context: str) -> PodLogManagerAsync: + if self._pod_manager is None: + self._pod_manager = PodLogManagerAsync( + k8s_context=k8s_context, token_retriever=self.k8s_token_retriever + ) + + return self._pod_manager + + async def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: + resp = await self.client.get_job_details([job_id]) + job_details = resp.job_details[job_id] + if job_details and job_details.latest_run_id: + for run in job_details.job_runs: + if run.run_id == job_details.latest_run_id: + return run + return None + + def __eq__(self, other): + if not isinstance(other, ArmadaTrigger): + return False + return ( + self.job_id == other.job_id + and self.channel_args.serialize() == other.channel_args.serialize() + and self.poll_interval == other.poll_interval + and self.tracking_message == other.tracking_message + ) diff --git a/third_party/airflow/examples/bad_armada.py b/third_party/airflow/examples/bad_armada.py index 8474eae0351..11bf545691e 100644 --- a/third_party/airflow/examples/bad_armada.py +++ b/third_party/airflow/examples/bad_armada.py @@ -1,5 +1,7 @@ from airflow import DAG from airflow.operators.bash import BashOperator + +from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 @@ -56,48 +58,44 @@ def submit_sleep_container(image: str): ) as dag: """ This Airflow DAG follows a similar pattern: - 1) Define arguments for armada and jobservice grpc channels. + 1) Define arguments for armada grpc channel. 2) Define your ArmadaOperator tasks that you want to run. 3) Generate a DAG definition. """ - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") op = BashOperator(task_id="dummy", bash_command="echo Hello World!") armada = ArmadaOperator( task_id="armada", name="armada", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_container(image="busybox"), + channel_args=armada_channel_args, + job_request=submit_sleep_container(image="busybox")[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) """ - This task is used to verify that if an Armada Job - fails we are correctly telling Airflow that it failed. - """ + This task is used to verify that if an Armada Job + fails we are correctly telling Airflow that it failed. 
+ """ bad_armada = ArmadaOperator( task_id="armada_fail", name="armada_fail", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_container(image="nonexistant"), + channel_args=armada_channel_args, + job_request=submit_sleep_container(image="busybox")[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) good_armada = ArmadaOperator( task_id="good_armada", name="good_armada", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_container(image="busybox"), + channel_args=armada_channel_args, + job_request=submit_sleep_container(image="busybox")[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) """ - Airflow syntax to say - Run op first and then run armada and bad_armada in parallel - If all jobs are successful, run good_armada. - """ + Airflow syntax to say + Run op first and then run armada and bad_armada in parallel + If all jobs are successful, run good_armada. + """ op >> [armada, bad_armada] >> good_armada diff --git a/third_party/airflow/examples/big_armada.py b/third_party/airflow/examples/big_armada.py index dc64cdc76b2..5979e391f0b 100644 --- a/third_party/airflow/examples/big_armada.py +++ b/third_party/airflow/examples/big_armada.py @@ -1,5 +1,7 @@ from airflow import DAG from airflow.operators.bash import BashOperator + +from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 @@ -62,11 +64,9 @@ def submit_sleep_job(): default_args={"retries": 2}, ) as dag: """ - The ArmadaDeferrableOperator requires grpc.channel arguments for armada and - the jobservice. + The ArmadaDeferrableOperator requires grpc.channel arguments for armada. """ - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") """ This defines an Airflow task that runs Hello World and it gives the airflow @@ -75,8 +75,7 @@ def submit_sleep_job(): op = BashOperator(task_id="dummy", bash_command="echo Hello World!") """ This is creating an Armada task with the task_id of armada and name of armada. - The Airflow operator needs queue and job-set for Armada - You also specify the PythonClient and JobServiceClient for each task. + The Airflow operator needs a queue for Armada You should reuse them for all your tasks. This job will use the podspec defined above. 
""" @@ -84,9 +83,8 @@ def submit_sleep_job(): task_id="armada1", name="armada1", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -94,9 +92,8 @@ def submit_sleep_job(): task_id="armada2", name="armada2", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -104,9 +101,8 @@ def submit_sleep_job(): task_id="armada3", name="armada3", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -114,9 +110,8 @@ def submit_sleep_job(): task_id="armada4", name="armada4", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -124,9 +119,8 @@ def submit_sleep_job(): task_id="armada5", name="armada5", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -134,9 +128,8 @@ def submit_sleep_job(): task_id="armada6", name="armada6", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -144,9 +137,8 @@ def submit_sleep_job(): task_id="armada7", name="armada7", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -154,9 +146,8 @@ def submit_sleep_job(): task_id="armada8", name="armada8", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -164,9 +155,8 @@ def submit_sleep_job(): task_id="armada9", name="armada9", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -174,9 +164,8 @@ def submit_sleep_job(): task_id="armada10", name="armada10", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - 
job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -184,9 +173,8 @@ def submit_sleep_job(): task_id="armada11", name="armada11", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -194,18 +182,16 @@ def submit_sleep_job(): task_id="armada12", name="armada12", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) armada13 = ArmadaOperator( task_id="armada13", name="armada13", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -213,9 +199,8 @@ def submit_sleep_job(): task_id="armada14", name="armada14", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -223,9 +208,8 @@ def submit_sleep_job(): task_id="armada15", name="armada15", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -233,9 +217,8 @@ def submit_sleep_job(): task_id="armada16", name="armada16", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -243,9 +226,8 @@ def submit_sleep_job(): task_id="armada17", name="armada17", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -253,18 +235,16 @@ def submit_sleep_job(): task_id="armada18", name="armada18", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) armada19 = ArmadaOperator( task_id="armada19", name="armada19", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -272,9 +252,8 @@ def submit_sleep_job(): task_id="armada20", 
name="armada20", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) diff --git a/third_party/airflow/examples/hello_armada.py b/third_party/airflow/examples/hello_armada.py index 53c20c78038..0f59932d96c 100644 --- a/third_party/airflow/examples/hello_armada.py +++ b/third_party/airflow/examples/hello_armada.py @@ -1,5 +1,7 @@ from airflow import DAG from airflow.operators.bash import BashOperator + +from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 @@ -62,11 +64,9 @@ def submit_sleep_job(): default_args={"retries": 2}, ) as dag: """ - The ArmadaOperator requires grpc.channel arguments for armada and - the jobservice. + The ArmadaOperator requires grpc.channel arguments for armada. """ - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") """ This defines an Airflow task that runs Hello World and it gives the airflow @@ -75,8 +75,7 @@ def submit_sleep_job(): op = BashOperator(task_id="dummy", bash_command="echo Hello World!") """ This is creating an Armada task with the task_id of armada and name of armada. - The Airflow operator needs queue and job-set for Armada - You also specify the PythonClient and JobServiceClient for each task. + The Airflow operator needs queue for Armada You should reuse them for all your tasks. This job will use the podspec defined above. """ @@ -84,9 +83,8 @@ def submit_sleep_job(): task_id="armada", name="armada", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) """ diff --git a/third_party/airflow/examples/hello_armada_deferrable.py b/third_party/airflow/examples/hello_armada_deferrable.py index 907242e4932..eb028d61a40 100644 --- a/third_party/airflow/examples/hello_armada_deferrable.py +++ b/third_party/airflow/examples/hello_armada_deferrable.py @@ -1,6 +1,5 @@ from airflow import DAG from airflow.operators.bash import BashOperator -from armada.operators.armada_deferrable import ArmadaDeferrableOperator from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 from armada_client.k8s.io.apimachinery.pkg.api.resource import ( @@ -13,6 +12,9 @@ import pendulum +from armada.model import GrpcChannelArgs +from armada.operators.armada import ArmadaOperator + def submit_sleep_job(): """ @@ -63,12 +65,9 @@ def submit_sleep_job(): default_args={"retries": 2}, ) as dag: """ - The ArmadaDeferrableOperatorOperator requires grpc.aio.channel arguments + The ArmadaOperator requires GrpcChannelArgs arguments """ - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = { - "target": "127.0.0.1:60003", - } + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") """ This defines an Airflow task that runs Hello World and it gives the airflow task name of dummy. 
@@ -76,19 +75,17 @@ def submit_sleep_job(): op = BashOperator(task_id="dummy", bash_command="echo Hello World!") """ This is creating an Armada task with the task_id of armada and name of armada. - The Airflow operator needs queue and job-set for Armada - You also specify the PythonClient and JobServiceClient channel arguments - for each task. + The Airflow operator needs queue for Armada. This job will use the podspec defined above. """ - armada = ArmadaDeferrableOperator( + armada = ArmadaOperator( task_id="armada_deferrable", name="armada_deferrable", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, + channel_args=armada_channel_args, armada_queue="test", - job_request_items=submit_sleep_job(), + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", + deferrable=True, ) """ Airflow dag syntax for running op and then armada. diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index 8a0ffec1c0c..bde16313944 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -1,30 +1,67 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + [project] name = "armada_airflow" -version = "0.5.6" +version = "1.0.0" description = "Armada Airflow Operator" -requires-python = ">=3.7" -# Note(JayF): This dependency value is not suitable for release. Whatever -# release automation we create will have to change this to a dep on a pypi -# package, but we can't do that now because it would make development -# extremely difficult. -dependencies = [ - "armada-client", - "apache-airflow>=2.6.3", - "grpcio==1.58.0", - "grpcio-tools==1.58.0", - "types-protobuf==4.24.0.1", - "protobuf>=3.20,<5.0" -] +readme='README.md' authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] license = { text = "Apache Software License" } -readme = "README.md" +dependencies=[ + 'armada-client==0.3.4', + 'apache-airflow>=2.6.3', + 'grpcio==1.58.0', + 'grpcio-tools==1.58.0', + 'types-protobuf==4.24.0.1', + 'kubernetes>=23.6.0', + 'kubernetes_asyncio>=24.2.3', +] +requires-python=">=3.8" +classifiers=[ + 'Programming Language :: Python :: 3', + 'Operating System :: OS Independent', +] [project.optional-dependencies] -format = ["black==23.7.0", "flake8==7.0.0", "pylint==2.17.5"] -test = ["pytest==7.3.1", "coverage==7.3.2", "pytest-asyncio==0.21.1"] +format = ["black>=24.0.0", "flake8==7.0.0", "pylint==2.17.5"] +test = ["pytest==7.3.1", "coverage==7.3.2", "pytest-asyncio==0.21.1", + "pytest-mock>=3.14.0"] # note(JayF): sphinx-jekyll-builder was broken by sphinx-markdown-builder 0.6 -- so pin to 0.5.5 docs = ["sphinx==7.1.2", "sphinx-jekyll-builder==0.3.0", "sphinx-toolbox==3.2.0b1", "sphinx-markdown-builder==0.5.5"] -[build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +[project.urls] +repository='https://github.com/armadaproject/armada' + +[tools.setuptools.packages.find] +include = ["armada_airflow*"] + +[tool.black] +line-length = 88 +target-version = ['py310'] +include = ''' +/( + armada + | test +)/ +''' +exclude = ''' +/( + \.git + | venv + | build + | dist + | new + | .tox + | docs + | armada_airflow.egg-info + | __pycache__* +)/ +''' + +[tool.flake8] +# These settings are reccomended by upstream black to make flake8 find black +# style formatting correct. 
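+# black's default line length is 88, and E203 (whitespace before ':') is
+# excluded because black intentionally formats slices that way.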
+max-line-length = 88 +extend-ignore = "E203" diff --git a/third_party/airflow/armada/__init__.py b/third_party/airflow/test/__init__.py similarity index 100% rename from third_party/airflow/armada/__init__.py rename to third_party/airflow/test/__init__.py diff --git a/third_party/airflow/test/integration/test_airflow_operator_logic.py b/third_party/airflow/test/integration/test_airflow_operator_logic.py new file mode 100644 index 00000000000..c2931715f70 --- /dev/null +++ b/third_party/airflow/test/integration/test_airflow_operator_logic.py @@ -0,0 +1,232 @@ +import os +import uuid +from unittest.mock import MagicMock + +import pytest +import threading + +from airflow.exceptions import AirflowException +from armada_client.typings import JobState +from armada_client.armada import ( + submit_pb2, +) +from armada_client.client import ArmadaClient +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) +import grpc +from typing import Any + +from armada.model import GrpcChannelArgs +from armada.operators.armada import ArmadaOperator + +DEFAULT_TASK_ID = "test_task_1" +DEFAULT_DAG_ID = "test_dag_1" +DEFAULT_RUN_ID = "test_run_1" +DEFAULT_QUEUE = "queue-a" +DEFAULT_NAMESPACE = "personal-anonymous" +DEFAULT_POLLING_INTERVAL = 1 +DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 10 + + +@pytest.fixture(scope="function", name="context") +def default_context() -> Any: + mock_ti = MagicMock() + mock_ti.task_id = DEFAULT_TASK_ID + mock_ti.xcom_pull.return_value = None + mock_ti.xcom_push.return_value = None + mock_dag = MagicMock() + mock_dag.dag_id = DEFAULT_DAG_ID + return { + "ti": mock_ti, + "run_id": DEFAULT_RUN_ID, + "dag": mock_dag, + } + + +@pytest.fixture(scope="session", name="channel_args") +def queryapi_channel_args() -> GrpcChannelArgs: + server_name = os.environ.get("ARMADA_SERVER", "localhost") + server_port = os.environ.get("ARMADA_PORT", "50051") + + return GrpcChannelArgs(target=f"{server_name}:{server_port}") + + +@pytest.fixture(scope="session", name="client") +def no_auth_client() -> ArmadaClient: + server_name = os.environ.get("ARMADA_SERVER", "localhost") + server_port = os.environ.get("ARMADA_PORT", "50051") + + return ArmadaClient(channel=grpc.insecure_channel(f"{server_name}:{server_port}")) + + +def sleep_pod(image: str): + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="goodsleep", + image=image, + args=["sleep", "5s"], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + requests={ + "cpu": api_resource.Quantity(string="0.2"), + "memory": api_resource.Quantity(string="64Mi"), + }, + limits={ + "cpu": api_resource.Quantity(string="0.2"), + "memory": api_resource.Quantity(string="64Mi"), + }, + ), + ) + ], + ) + return [ + submit_pb2.JobSubmitRequestItem( + priority=1, pod_spec=pod, namespace=DEFAULT_NAMESPACE + ) + ] + + +def test_success_job( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + job_set_name = f"test-{uuid.uuid1()}" + job = client.submit_jobs( + queue=DEFAULT_QUEUE, + job_set_id=job_set_name, + job_request_items=sleep_pod(image="busybox"), + ) + job_id = job.job_response_items[0].job_id + + mocker.patch( + "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", + return_value=job_id, + ) + + operator = ArmadaOperator( + task_id=DEFAULT_TASK_ID, + name="test_job_success", + channel_args=channel_args, + armada_queue=DEFAULT_QUEUE, + 
job_request=sleep_pod(image="busybox")[0], + poll_interval=DEFAULT_POLLING_INTERVAL, + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + ) + + operator.execute(context) + + response = operator.client.get_job_status([job_id]) + assert JobState(response.job_states[job_id]) == JobState.SUCCEEDED + + +def test_bad_job( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + job_set_name = f"test-{uuid.uuid1()}" + job = client.submit_jobs( + queue=DEFAULT_QUEUE, + job_set_id=job_set_name, + job_request_items=sleep_pod(image="NOTACONTAINER"), + ) + job_id = job.job_response_items[0].job_id + + mocker.patch( + "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", + return_value=job_id, + ) + + operator = ArmadaOperator( + task_id=DEFAULT_TASK_ID, + name="test_job_failure", + channel_args=channel_args, + armada_queue=DEFAULT_QUEUE, + job_request=sleep_pod(image="busybox")[0], + poll_interval=DEFAULT_POLLING_INTERVAL, + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + ) + + try: + operator.execute(context) + pytest.fail( + "Operator did not raise AirflowException on job failure as expected" + ) + except AirflowException: # Expected + response = operator.client.get_job_status([job_id]) + assert JobState(response.job_states[job_id]) == JobState.FAILED + except Exception as e: + pytest.fail( + "Operator did not raise AirflowException on job failure as expected, " + f"raised {e} instead" + ) + + +def success_job( + task_number: int, context: Any, channel_args: GrpcChannelArgs +) -> JobState: + operator = ArmadaOperator( + task_id=f"{DEFAULT_TASK_ID}_{task_number}", + name="test_job_success", + channel_args=channel_args, + armada_queue=DEFAULT_QUEUE, + job_request=sleep_pod(image="busybox")[0], + poll_interval=DEFAULT_POLLING_INTERVAL, + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + ) + + operator.execute(context) + + response = operator.client.get_job_status([operator.job_id]) + return JobState(response.job_states[operator.job_id]) + + +@pytest.mark.skip(reason="we should not test performance in the CI.") +def test_parallel_execution( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + threads = [] + success_job(task_number=0, context=context, channel_args=channel_args) + for task_number in range(5): + t = threading.Thread( + target=success_job, args=[task_number, context, channel_args] + ) + t.start() + threads.append(t) + + for thread in threads: + thread.join() + + +@pytest.mark.skip(reason="we should not test performance in the CI.") +def test_parallel_execution_large( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + threads = [] + success_job(task_number=0, context=context, channel_args=channel_args) + for task_number in range(80): + t = threading.Thread( + target=success_job, args=[task_number, context, channel_args] + ) + t.start() + threads.append(t) + + for thread in threads: + thread.join() + + +@pytest.mark.skip(reason="we should not test performance in the CI.") +def test_parallel_execution_huge( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + threads = [] + success_job(task_number=0, context=context, channel_args=channel_args) + for task_number in range(500): + t = threading.Thread( + target=success_job, args=[task_number, context, channel_args] + ) + t.start() + threads.append(t) + + for thread in threads: + thread.join() diff --git a/third_party/airflow/armada/jobservice/__init__.py 
b/third_party/airflow/test/operators/__init__.py
similarity index 100%
rename from third_party/airflow/armada/jobservice/__init__.py
rename to third_party/airflow/test/operators/__init__.py
diff --git a/third_party/airflow/test/operators/test_armada.py b/third_party/airflow/test/operators/test_armada.py
new file mode 100644
index 00000000000..1f134ce3411
--- /dev/null
+++ b/third_party/airflow/test/operators/test_armada.py
@@ -0,0 +1,310 @@
+import unittest
+from math import ceil
+from unittest.mock import MagicMock, patch, PropertyMock
+
+from airflow.exceptions import AirflowException
+from armada_client.armada import submit_pb2, job_pb2
+from armada_client.armada.submit_pb2 import JobSubmitRequestItem
+from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1
+from armada_client.k8s.io.apimachinery.pkg.api.resource import (
+    generated_pb2 as api_resource,
+)
+
+from armada.model import GrpcChannelArgs
+from armada.operators.armada import ArmadaOperator
+from armada.triggers.armada import ArmadaTrigger
+
+DEFAULT_JOB_ID = "test_job"
+DEFAULT_TASK_ID = "test_task_1"
+DEFAULT_DAG_ID = "test_dag_1"
+DEFAULT_RUN_ID = "test_run_1"
+DEFAULT_QUEUE = "test_queue_1"
+DEFAULT_POLLING_INTERVAL = 30
+DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 5 * 60
+
+
+class TestArmadaOperator(unittest.TestCase):
+    def setUp(self):
+        # Set up a mock context
+        mock_ti = MagicMock()
+        mock_ti.task_id = DEFAULT_TASK_ID
+        mock_dag = MagicMock()
+        mock_dag.dag_id = DEFAULT_DAG_ID
+        self.context = {
+            "ti": mock_ti,
+            "run_id": DEFAULT_RUN_ID,
+            "dag": mock_dag,
+        }
+
+    @patch("time.sleep", return_value=None)
+    @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock)
+    def test_execute(self, mock_client_fn, _):
+        test_cases = [
+            {
+                "name": "Job Succeeds",
+                "statuses": [submit_pb2.RUNNING, submit_pb2.SUCCEEDED],
+                "success": True,
+            },
+            {
+                "name": "Job Failed",
+                "statuses": [submit_pb2.RUNNING, submit_pb2.FAILED],
+                "success": False,
+            },
+            {
+                "name": "Job cancelled",
+                "statuses": [submit_pb2.RUNNING, submit_pb2.CANCELLED],
+                "success": False,
+            },
+            {
+                "name": "Job preempted",
+                "statuses": [submit_pb2.RUNNING, submit_pb2.PREEMPTED],
+                "success": False,
+            },
+            {
+                "name": "Job Succeeds but takes a lot of transitions",
+                "statuses": [
+                    submit_pb2.SUBMITTED,
+                    submit_pb2.RUNNING,
+                    submit_pb2.RUNNING,
+                    submit_pb2.RUNNING,
+                    submit_pb2.RUNNING,
+                    submit_pb2.RUNNING,
+                    submit_pb2.SUCCEEDED,
+                ],
+                "success": True,
+            },
+        ]
+
+        for test_case in test_cases:
+            with self.subTest(test_case=test_case["name"]):
+                operator = ArmadaOperator(
+                    name="test",
+                    channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"),
+                    armada_queue=DEFAULT_QUEUE,
+                    job_request=JobSubmitRequestItem(),
+                    task_id=DEFAULT_TASK_ID,
+                )
+
+                # Set up Mock Armada
+                mock_client = MagicMock()
+                mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse(
+                    job_response_items=[
+                        submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)
+                    ]
+                )
+
+                mock_client.get_job_status.side_effect = [
+                    job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x})
+                    for x in test_case["statuses"]
+                ]
+
+                mock_client_fn.return_value = mock_client
+                self.context["ti"].xcom_pull.return_value = None
+
+                try:
+                    operator.execute(self.context)
+                    self.assertTrue(test_case["success"])
+                except AirflowException:
+                    self.assertFalse(test_case["success"])
+                    # continue (not return) so the remaining subtests still run
+                    continue
+
+                self.assertEqual(mock_client.submit_jobs.call_count, 1)
+                self.assertEqual(
+                    mock_client.get_job_status.call_count, len(test_case["statuses"])
+                )
+
@patch("time.sleep", return_value=None) + @patch("armada.operators.armada.ArmadaOperator.on_kill", new_callable=PropertyMock) + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_unacknowledged_results_in_on_kill(self, mock_client_fn, mock_on_kill, _): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=False, + job_acknowledgement_timeout=-1, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( + job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] + ) + mock_client_fn.return_value = mock_client + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [submit_pb2.UNKNOWN, submit_pb2.UNKNOWN] + ] + + self.context["ti"].xcom_pull.return_value = None + operator.execute(self.context) + self.assertEqual(mock_on_kill.call_count, 1) + + """We call on_kill by triggering the job unacknowledged timeout""" + + @patch("time.sleep", return_value=None) + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_on_kill_cancels_job(self, mock_client_fn, _): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=False, + job_acknowledgement_timeout=-1, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( + job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] + ) + mock_client_fn.return_value = mock_client + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [ + submit_pb2.UNKNOWN + for _ in range( + 1 + + ceil( + DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL + ) + ) + ] + ] + + self.context["ti"].xcom_pull.return_value = None + operator.execute(self.context) + self.assertEqual(mock_client.cancel_jobs.call_count, 1) + + @patch("time.sleep", return_value=None) + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_job_reattaches(self, mock_client_fn, _): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=False, + job_acknowledgement_timeout=-1, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [ + submit_pb2.UNKNOWN + for _ in range( + 1 + + ceil( + DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL + ) + ) + ] + ] + mock_client_fn.return_value = mock_client + self.context["ti"].xcom_pull.return_value = {"armada_job_id": DEFAULT_JOB_ID} + + operator.execute(self.context) + self.assertEqual(mock_client.submit_jobs.call_count, 0) + self.assertEqual(operator.job_id, DEFAULT_JOB_ID) + + +class TestArmadaOperatorDeferrable(unittest.IsolatedAsyncioTestCase): + def setUp(self): + # Set up a mock context + mock_ti = MagicMock() + mock_ti.task_id = DEFAULT_TASK_ID + mock_dag = MagicMock() + mock_dag.dag_id = DEFAULT_DAG_ID + self.context = { + 
"ti": mock_ti, + "run_id": DEFAULT_RUN_ID, + "dag": mock_dag, + } + + @patch("armada.operators.armada.ArmadaOperator.defer") + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_execute_deferred(self, mock_client_fn, mock_defer_fn): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=True, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( + job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] + ) + mock_client_fn.return_value = mock_client + self.context["ti"].xcom_pull.return_value = None + + operator.execute(self.context) + self.assertEqual(mock_client.submit_jobs.call_count, 1) + mock_defer_fn.assert_called_with( + timeout=operator.execution_timeout, + trigger=ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=operator.job_set_id, # Not relevant for the sake of test + channel_args=operator.channel_args, + poll_interval=operator.poll_interval, + tracking_message="", + job_acknowledgement_timeout=operator.job_acknowledgement_timeout, + job_request_namespace="default", + ), + method_name="_execute_complete", + ) + + def test_templating(self): + """Tests templating for both the job_prefix and the pod spec""" + prefix = "{{ run_id }}" + pod_arg = "{{ run_id }}" + + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="sleep", + image="alpine:3.20.1", + args=[pod_arg], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + requests={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + limits={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + ), + ) + ], + ) + job = JobSubmitRequestItem(priority=1, pod_spec=pod, namespace="armada") + + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=job, + job_set_prefix=prefix, + task_id=DEFAULT_TASK_ID, + deferrable=True, + ) + + operator.render_template_fields(self.context) + + self.assertEqual(operator.job_set_prefix, "test_run_1") + self.assertEqual( + operator.job_request.pod_spec.containers[0].args[0], "test_run_1" + ) diff --git a/third_party/airflow/armada/operators/__init__.py b/third_party/airflow/test/triggers/__init__.py similarity index 100% rename from third_party/airflow/armada/operators/__init__.py rename to third_party/airflow/test/triggers/__init__.py diff --git a/third_party/airflow/test/triggers/test_armada.py b/third_party/airflow/test/triggers/test_armada.py new file mode 100644 index 00000000000..29ba4f20990 --- /dev/null +++ b/third_party/airflow/test/triggers/test_armada.py @@ -0,0 +1,207 @@ +import unittest +from unittest.mock import AsyncMock, patch, PropertyMock + +from airflow.triggers.base import TriggerEvent +from armada_client.armada.submit_pb2 import JobState +from armada_client.armada import submit_pb2, job_pb2 + +from armada.model import GrpcChannelArgs +from armada.triggers.armada import ArmadaTrigger + +DEFAULT_JOB_ID = "test_job" +DEFAULT_QUEUE = "test_queue" +DEFAULT_JOB_SET_ID = "test_job_set_id" +DEFAULT_POLLING_INTERVAL = 30 +DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 5 * 60 + + +class 
AsyncMock(unittest.mock.MagicMock): # noqa: F811 + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + +class TestArmadaTrigger(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.time = 0 + + def test_serialization(self): + trigger = ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=DEFAULT_JOB_SET_ID, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + poll_interval=30, + tracking_message="test tracking message", + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + job_request_namespace="default", + ) + classpath, kwargs = trigger.serialize() + self.assertEqual("armada.triggers.armada.ArmadaTrigger", classpath) + + rehydrated = ArmadaTrigger(**kwargs) + self.assertEqual(trigger, rehydrated) + + def _time_side_effect(self): + self.time += DEFAULT_POLLING_INTERVAL + return self.time + + @patch("time.time") + @patch("asyncio.sleep", new_callable=AsyncMock) + @patch("armada.triggers.armada.ArmadaTrigger.client", new_callable=PropertyMock) + async def test_execute(self, mock_client_fn, _, time_time): + time_time.side_effect = self._time_side_effect + + test_cases = [ + { + "name": "Job Succeeds", + "statuses": [JobState.RUNNING, JobState.SUCCEEDED], + "expected_responses": [ + TriggerEvent( + { + "status": "success", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} succeeded", + } + ) + ], + }, + { + "name": "Job Failed", + "statuses": [JobState.RUNNING, JobState.FAILED], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} did not succeed. " + f"Final status was FAILED", + } + ) + ], + }, + { + "name": "Job cancelled", + "statuses": [JobState.RUNNING, JobState.CANCELLED], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} did not succeed." + f" Final status was CANCELLED", + } + ) + ], + }, + { + "name": "Job unacknowledged", + "statuses": [JobState.UNKNOWN for _ in range(6)], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} not acknowledged wit" + f"hin timeout {DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT}.", + } + ) + ], + }, + { + "name": "Job preempted", + "statuses": [JobState.RUNNING, JobState.PREEMPTED], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} did not succeed." 
+ f" Final status was PREEMPTED", + } + ) + ], + }, + { + "name": "Job Succeeds but takes a lot of transitions", + "statuses": [ + JobState.SUBMITTED, + JobState.RUNNING, + JobState.RUNNING, + JobState.RUNNING, + JobState.RUNNING, + JobState.RUNNING, + JobState.SUCCEEDED, + ], + "success": True, + "expected_responses": [ + TriggerEvent( + { + "status": "success", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} succeeded", + } + ) + ], + }, + ] + + for test_case in test_cases: + with self.subTest(test_case=test_case["name"]): + trigger = ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=DEFAULT_JOB_SET_ID, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + poll_interval=DEFAULT_POLLING_INTERVAL, + tracking_message="some tracking message", + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + job_request_namespace="default", + ) + + # Setup Mock Armada + mock_client = AsyncMock() + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in test_case["statuses"] + ] + mock_client.cancel_jobs.return_value = submit_pb2.CancellationResult( + cancelled_ids=[DEFAULT_JOB_ID] + ) + mock_client_fn.return_value = mock_client + responses = [gen async for gen in trigger.run()] + self.assertEqual(test_case["expected_responses"], responses) + self.assertEqual( + len(test_case["statuses"]), mock_client.get_job_status.call_count + ) + + @patch("time.sleep", return_value=None) + @patch("armada.triggers.armada.ArmadaTrigger.client", new_callable=PropertyMock) + async def test_unacknowledged_results_in_job_cancel(self, mock_client_fn, _): + trigger = ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=DEFAULT_JOB_SET_ID, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + poll_interval=DEFAULT_POLLING_INTERVAL, + tracking_message="some tracking message", + job_acknowledgement_timeout=-1, + job_request_namespace="default", + ) + + # Set up Mock Armada + mock_client = AsyncMock() + mock_client.cancel_jobs.return_value = submit_pb2.CancellationResult( + cancelled_ids=[DEFAULT_JOB_ID] + ) + mock_client_fn.return_value = mock_client + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [JobState.UNKNOWN, JobState.UNKNOWN] + ] + [gen async for gen in trigger.run()] + + self.assertEqual(mock_client.cancel_jobs.call_count, 1) diff --git a/third_party/airflow/tests/__init__.py b/third_party/airflow/tests/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/third_party/airflow/tests/integration/test_airflow_operator_logic.py b/third_party/airflow/tests/integration/test_airflow_operator_logic.py deleted file mode 100644 index f65ced29a67..00000000000 --- a/third_party/airflow/tests/integration/test_airflow_operator_logic.py +++ /dev/null @@ -1,166 +0,0 @@ -import os -import uuid -import pytest -import threading - -from armada_client.armada import ( - submit_pb2, -) -from armada_client.client import ArmadaClient -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) -import grpc - -from armada.operators.jobservice import JobServiceClient -from armada.operators.utils import JobState, search_for_job_complete - - -@pytest.fixture(scope="session", name="jobservice") -def job_service_client() -> ArmadaClient: - server_name = 
os.environ.get("JOB_SERVICE_HOST", "localhost") - server_port = os.environ.get("JOB_SERVICE_PORT", "60003") - - return JobServiceClient( - channel=grpc.insecure_channel(f"{server_name}:{server_port}") - ) - - -@pytest.fixture(scope="session", name="client") -def no_auth_client() -> ArmadaClient: - server_name = os.environ.get("ARMADA_SERVER", "localhost") - server_port = os.environ.get("ARMADA_PORT", "50051") - - return ArmadaClient(channel=grpc.insecure_channel(f"{server_name}:{server_port}")) - - -def sleep_pod(image: str): - pod = core_v1.PodSpec( - containers=[ - core_v1.Container( - name="goodsleep", - image=image, - args=["sleep", "10s"], - securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="0.2"), - "memory": api_resource.Quantity(string="64Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="0.2"), - "memory": api_resource.Quantity(string="64Mi"), - }, - ), - ) - ], - ) - return [ - submit_pb2.JobSubmitRequestItem( - priority=1, pod_spec=pod, namespace="personal-anonymous" - ) - ] - - -def test_success_job(client: ArmadaClient, jobservice: JobServiceClient): - job_set_name = f"test-{uuid.uuid1()}" - job = client.submit_jobs( - queue="queue-a", - job_set_id=job_set_name, - job_request_items=sleep_pod(image="busybox"), - ) - job_id = job.job_response_items[0].job_id - - job_state, job_message = search_for_job_complete( - job_service_client=jobservice, - armada_queue="queue-a", - job_set_id=job_set_name, - airflow_task_name="test", - job_id=job_id, - ) - assert job_state == JobState.SUCCEEDED - assert job_message == f"Armada test:{job_id} succeeded" - - -def test_bad_job(client: ArmadaClient, jobservice: JobServiceClient): - job_set_name = f"test-{uuid.uuid1()}" - - job = client.submit_jobs( - queue="queue-a", - job_set_id=job_set_name, - job_request_items=sleep_pod(image="NOTACONTAINER"), - ) - job_id = job.job_response_items[0].job_id - - job_state, job_message = search_for_job_complete( - job_service_client=jobservice, - armada_queue="queue-a", - job_set_id=job_set_name, - airflow_task_name="test", - job_id=job_id, - ) - assert job_state == JobState.FAILED - assert job_message.startswith(f"Armada test:{job_id} failed") - - -job_set_name = "test" - - -def success_job(client: ArmadaClient, jobservice: JobServiceClient): - job = client.submit_jobs( - queue="queue-a", - job_set_id=job_set_name, - job_request_items=sleep_pod(image="busybox"), - ) - job_id = job.job_response_items[0].job_id - - job_state, job_message = search_for_job_complete( - job_service_client=jobservice, - armada_queue="queue-a", - job_set_id=job_set_name, - airflow_task_name="test", - job_id=job_id, - ) - - assert job_state == JobState.SUCCEEDED - assert job_message == f"Armada test:{job_id} succeeded" - - -@pytest.mark.skip(reason="we should not test performance in the CI.") -def test_parallel_execution(client: ArmadaClient, jobservice: JobServiceClient): - threads = [] - success_job(client=client, jobservice=jobservice) - for _ in range(30): - t = threading.Thread(target=success_job, args=[client, jobservice]) - t.start() - threads.append(t) - - for thread in threads: - thread.join() - - -@pytest.mark.skip(reason="we should not test performance in the CI.") -def test_parallel_execution_large(client: ArmadaClient, jobservice: JobServiceClient): - threads = [] - success_job(client=client, jobservice=jobservice) - for _ in range(80): - t = threading.Thread(target=success_job, args=[client, jobservice]) - 
t.start() - threads.append(t) - - for thread in threads: - thread.join() - - -@pytest.mark.skip(reason="we should not test performance in the CI.") -def test_parallel_execution_huge(client: ArmadaClient, jobservice: JobServiceClient): - threads = [] - success_job(client=client, jobservice=jobservice) - for _ in range(500): - t = threading.Thread(target=success_job, args=[client, jobservice]) - t.start() - threads.append(t) - - for thread in threads: - thread.join() diff --git a/third_party/airflow/tests/unit/armada_client_mock.py b/third_party/airflow/tests/unit/armada_client_mock.py deleted file mode 100644 index fa3fd669a7f..00000000000 --- a/third_party/airflow/tests/unit/armada_client_mock.py +++ /dev/null @@ -1,35 +0,0 @@ -from google.protobuf import empty_pb2 -from armada_client.armada import submit_pb2_grpc, submit_pb2, event_pb2, event_pb2_grpc - - -class SubmitService(submit_pb2_grpc.SubmitServicer): - def CreateQueue(self, request, context): - return empty_pb2.Empty() - - def DeleteQueue(self, request, context): - return empty_pb2.Empty() - - def GetQueue(self, request, context): - return submit_pb2.Queue(name=request.name) - - def SubmitJobs(self, request, context): - submit_items = submit_pb2.JobSubmitResponseItem(job_id="mock") - - return submit_pb2.JobSubmitResponse(job_response_items=[submit_items]) - - def GetQueueInfo(self, request, context): - return submit_pb2.QueueInfo() - - def CancelJobs(self, request, context): - return submit_pb2.CancellationResult() - - def ReprioritizeJobs(self, request, context): - return submit_pb2.JobReprioritizeResponse() - - def UpdateQueue(self, request, context): - return empty_pb2.Empty() - - -class EventService(event_pb2_grpc.EventServicer): - def Watch(self, request, context): - return event_pb2.EventMessage() diff --git a/third_party/airflow/tests/unit/job_service_mock.py b/third_party/airflow/tests/unit/job_service_mock.py deleted file mode 100644 index 62d9641a9e5..00000000000 --- a/third_party/airflow/tests/unit/job_service_mock.py +++ /dev/null @@ -1,65 +0,0 @@ -import grpc - -from armada.jobservice import jobservice_pb2, jobservice_pb2_grpc - - -# TODO - Make this a bit smarter, so we can hit at least one full -# loop in search_for_job_complete. 
-def mock_dummy_mapper_terminal(request): - if request.job_id == "test_failed": - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.FAILED, error="Test Error" - ) - if request.job_id == "test_succeeded": - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.SUCCEEDED - ) - if request.job_id == "test_cancelled": - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.CANCELLED - ) - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.JOB_ID_NOT_FOUND - ) - - -class JobService(jobservice_pb2_grpc.JobServiceServicer): - def GetJobStatus(self, request, context): - return mock_dummy_mapper_terminal(request) - - def Health(self, request, context): - return jobservice_pb2.HealthCheckResponse( - status=jobservice_pb2.HealthCheckResponse.SERVING - ) - - -class JobServiceOccasionalError(jobservice_pb2_grpc.JobServiceServicer): - def __init__(self): - self.get_job_status_count = 0 - self.health_count = 0 - - def GetJobStatus(self, request, context): - self.get_job_status_count += 1 - if self.get_job_status_count % 3 == 0: - context.set_code(grpc.StatusCode.UNAVAILABLE) - context.set_details("Injected error") - raise Exception("Injected error") - - if self.get_job_status_count < 5: - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.RUNNING - ) - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.SUCCEEDED - ) - - def Health(self, request, context): - self.health_count += 1 - if self.health_count % 3 == 0: - context.set_code(grpc.StatusCode.UNAVAILABLE) - context.set_details("Injected error") - raise Exception("Injected error") - - return jobservice_pb2.HealthCheckResponse( - status=jobservice_pb2.HealthCheckResponse.SERVING - ) diff --git a/third_party/airflow/tests/unit/server_mock.py b/third_party/airflow/tests/unit/server_mock.py deleted file mode 100644 index bbadc20964f..00000000000 --- a/third_party/airflow/tests/unit/server_mock.py +++ /dev/null @@ -1,100 +0,0 @@ -from google.protobuf import empty_pb2 -from armada_client.armada import ( - submit_pb2_grpc, - submit_pb2, - event_pb2, - event_pb2_grpc, - health_pb2, -) - - -class SubmitService(submit_pb2_grpc.SubmitServicer): - def CreateQueue(self, request, context): - return empty_pb2.Empty() - - def DeleteQueue(self, request, context): - return empty_pb2.Empty() - - def GetQueue(self, request, context): - return submit_pb2.Queue(name=request.name) - - def SubmitJobs(self, request, context): - # read job_ids from request.job_request_items - job_ids = [f"job-{i}" for i in range(1, len(request.job_request_items) + 1)] - - job_response_items = [ - submit_pb2.JobSubmitResponseItem(job_id=job_id) for job_id in job_ids - ] - - return submit_pb2.JobSubmitResponse(job_response_items=job_response_items) - - def GetQueueInfo(self, request, context): - return submit_pb2.QueueInfo(name=request.name) - - def CancelJobs(self, request, context): - return submit_pb2.CancellationResult( - cancelled_ids=["job-1"], - ) - - def CancelJobSet(self, request, context): - return empty_pb2.Empty() - - def ReprioritizeJobs(self, request, context): - new_priority = request.new_priority - if len(request.job_ids) > 0: - job_id = request.job_ids[0] - results = { - f"{job_id}": new_priority, - } - - else: - queue = request.queue - job_set_id = request.job_set_id - - results = { - f"{queue}/{job_set_id}": new_priority, - } - - # convert the result dict into a list of tuples - # while 
also converting ints to strings - - results = [(k, str(v)) for k, v in results.items()] - - return submit_pb2.JobReprioritizeResponse(reprioritization_results=results) - - def UpdateQueue(self, request, context): - return empty_pb2.Empty() - - def CreateQueues(self, request, context): - return submit_pb2.BatchQueueCreateResponse( - failed_queues=[ - submit_pb2.QueueCreateResponse(queue=submit_pb2.Queue(name=queue.name)) - for queue in request.queues - ] - ) - - def UpdateQueues(self, request, context): - return submit_pb2.BatchQueueUpdateResponse( - failed_queues=[ - submit_pb2.QueueUpdateResponse(queue=submit_pb2.Queue(name=queue.name)) - for queue in request.queues - ] - ) - - def Health(self, request, context): - return health_pb2.HealthCheckResponse( - status=health_pb2.HealthCheckResponse.SERVING - ) - - -class EventService(event_pb2_grpc.EventServicer): - def GetJobSetEvents(self, request, context): - events = [event_pb2.EventStreamMessage()] - - for event in events: - yield event - - def Health(self, request, context): - return health_pb2.HealthCheckResponse( - status=health_pb2.HealthCheckResponse.SERVING - ) diff --git a/third_party/airflow/tests/unit/test_airflow_error.py b/third_party/airflow/tests/unit/test_airflow_error.py deleted file mode 100644 index 1e51c08e5ff..00000000000 --- a/third_party/airflow/tests/unit/test_airflow_error.py +++ /dev/null @@ -1,24 +0,0 @@ -from armada.operators.utils import JobState, airflow_error -from airflow.exceptions import AirflowException -import pytest - -testdata_success = [JobState.SUCCEEDED] - - -@pytest.mark.parametrize("state", testdata_success) -def test_airflow_error_successful(state): - airflow_error(state, "hello", "id") - - -testdata_error = [ - (JobState.FAILED, "The Armada job hello:id FAILED"), - (JobState.CANCELLED, "The Armada job hello:id CANCELLED"), - (JobState.JOB_ID_NOT_FOUND, "The Armada job hello:id JOB_ID_NOT_FOUND"), -] - - -@pytest.mark.parametrize("state, expected_exception_message", testdata_error) -def test_airflow_error_states(state, expected_exception_message): - with pytest.raises(AirflowException) as airflow: - airflow_error(state, "hello", "id") - assert str(airflow.value) == expected_exception_message diff --git a/third_party/airflow/tests/unit/test_airflow_operator_mock.py b/third_party/airflow/tests/unit/test_airflow_operator_mock.py deleted file mode 100644 index 1ab2d37ced1..00000000000 --- a/third_party/airflow/tests/unit/test_airflow_operator_mock.py +++ /dev/null @@ -1,217 +0,0 @@ -from airflow import DAG -from airflow.models.taskinstance import TaskInstance -from airflow.utils.context import Context -from armada_client.client import ArmadaClient -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) - -import grpc -from concurrent import futures -from armada_client.armada import submit_pb2_grpc, submit_pb2, event_pb2_grpc - -import pendulum -import pytest -from armada.operators.armada import ArmadaOperator, annotate_job_request_items -from armada.operators.jobservice import JobServiceClient -from armada.operators.utils import JobState, search_for_job_complete -from armada.jobservice import jobservice_pb2_grpc, jobservice_pb2 -from armada_client_mock import SubmitService, EventService -from job_service_mock import JobService - - -@pytest.fixture(scope="session", autouse=True) -def server_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - 
submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server) - event_pb2_grpc.add_EventServicer_to_server(EventService(), server) - server.add_insecure_port("[::]:50099") - server.start() - - yield - server.stop(False) - - -@pytest.fixture(scope="session", autouse=True) -def job_service_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - jobservice_pb2_grpc.add_JobServiceServicer_to_server(JobService(), server) - server.add_insecure_port("[::]:60081") - server.start() - - yield - server.stop(False) - - -tester_client = ArmadaClient( - grpc.insecure_channel( - target="127.0.0.1:50099", - ) -) -tester_jobservice = JobServiceClient(grpc.insecure_channel(target="127.0.0.1:60081")) - - -def generate_pod_spec(name: str = "container-1") -> core_v1.PodSpec: - ps = core_v1.PodSpec( - containers=[ - core_v1.Container( - name=name, - image="busybox", - args=["sleep", "10s"], - securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - ), - ) - ], - ) - return ps - - -def sleep_job(): - pod = generate_pod_spec() - return [submit_pb2.JobSubmitRequestItem(priority=0, pod_spec=pod)] - - -def pre_template_sleep_job(): - pod = generate_pod_spec(name="name-{{ run_id }}") - return [submit_pb2.JobSubmitRequestItem(priority=0, pod_spec=pod)] - - -def expected_sleep_job(): - pod = generate_pod_spec(name="name-another-run-id") - return [submit_pb2.JobSubmitRequestItem(priority=0, pod_spec=pod)] - - -def test_job_service_health(): - health = tester_jobservice.health() - assert health.status == jobservice_pb2.HealthCheckResponse.SERVING - - -def test_mock_success_job(): - tester_client.submit_jobs( - queue="test", - job_set_id="test", - job_request_items=sleep_job(), - ) - - job_state, job_message = search_for_job_complete( - job_service_client=tester_jobservice, - armada_queue="test", - job_set_id="test", - airflow_task_name="test-mock", - job_id="test_succeeded", - ) - assert job_state == JobState.SUCCEEDED - assert job_message == "Armada test-mock:test_succeeded succeeded" - - -def test_mock_failed_job(): - tester_client.submit_jobs( - queue="test", - job_set_id="test", - job_request_items=sleep_job(), - ) - - job_state, job_message = search_for_job_complete( - job_service_client=tester_jobservice, - armada_queue="test", - job_set_id="test", - airflow_task_name="test-mock", - job_id="test_failed", - ) - assert job_state == JobState.FAILED - assert job_message.startswith("Armada test-mock:test_failed failed") - - -def test_mock_cancelled_job(): - tester_client.submit_jobs( - queue="test", - job_set_id="test", - job_request_items=sleep_job(), - ) - - job_state, job_message = search_for_job_complete( - job_service_client=tester_jobservice, - armada_queue="test", - job_set_id="test", - airflow_task_name="test-mock", - job_id="test_cancelled", - ) - assert job_state == JobState.CANCELLED - assert job_message == "Armada test-mock:test_cancelled cancelled" - - -def test_annotate_job_request_items(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - job_request_items = sleep_job() - task_id = "58896abbfr9" - operator = ArmadaOperator( - task_id=task_id, - name="armada-task", - armada_queue="test", - job_service_channel_args=job_service_channel_args, - 
armada_channel_args=armada_channel_args, - job_request_items=job_request_items, - lookout_url_template="http://127.0.0.1:8089", - ) - - task_instance = TaskInstance(operator) - dag = DAG( - dag_id="hello_armada", - start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule="@daily", - catchup=False, - default_args={"retries": 2}, - ) - context = {"ti": task_instance, "dag": dag, "run_id": "some-run-id"} - - result = annotate_job_request_items(context, job_request_items) - assert result[0].annotations == { - "armadaproject.io/taskId": task_id, - "armadaproject.io/taskRunId": "some-run-id", - "armadaproject.io/dagId": "hello_armada", - } - - -def test_parameterize_armada_operator(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - submitted_job_request_items = pre_template_sleep_job() - expected_job_request_items = expected_sleep_job() - task_id = "123456789ab" - operator = ArmadaOperator( - task_id=task_id, - name="armada-task", - armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submitted_job_request_items, - lookout_url_template="http://127.0.0.1:8089", - ) - task_instance = TaskInstance(operator) - dag = DAG( - dag_id="hello_armada", - start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule="@daily", - catchup=False, - default_args={"retries": 2}, - ) - context = Context(ti=task_instance, dag=dag, run_id="another-run-id") - - assert operator.job_request_items != expected_job_request_items - - operator.render_template_fields(context) - - assert operator.job_request_items == expected_job_request_items diff --git a/third_party/airflow/tests/unit/test_armada_deferrable_operator.py b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py deleted file mode 100644 index 0f156ed177e..00000000000 --- a/third_party/airflow/tests/unit/test_armada_deferrable_operator.py +++ /dev/null @@ -1,171 +0,0 @@ -import copy - -import pytest - -from armada_client.armada import submit_pb2 -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) -from armada.operators.armada_deferrable import ArmadaDeferrableOperator -from armada.operators.grpc import CredentialsCallback - - -def test_serialize_armada_deferrable(): - grpc_chan_args = { - "target": "localhost:443", - "credentials_callback_args": { - "module_name": "channel_test", - "function_name": "get_credentials", - "function_kwargs": { - "example_arg": "test", - }, - }, - } - - pod = core_v1.PodSpec( - containers=[ - core_v1.Container( - name="sleep", - image="busybox", - args=["sleep", "10s"], - securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - ), - ) - ], - ) - - job_requests = [ - submit_pb2.JobSubmitRequestItem( - priority=1, - pod_spec=pod, - namespace="personal-anonymous", - annotations={"armadaproject.io/hello": "world"}, - ) - ] - - source = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test task", - armada_channel_args=grpc_chan_args, - job_service_channel_args=grpc_chan_args, - armada_queue="test-queue", - job_request_items=job_requests, - 
lookout_url_template="https://lookout.test.domain/", - poll_interval=5, - ) - - serialized = source.serialize() - assert serialized["name"] == source.name - - reconstituted = ArmadaDeferrableOperator(**serialized) - assert reconstituted == source - - -get_lookout_url_test_cases = [ - ( - "http://localhost:8089/jobs?job_id=", - "test_id", - "http://localhost:8089/jobs?job_id=test_id", - ), - ( - "https://lookout.armada.domain/jobs?job_id=", - "test_id", - "https://lookout.armada.domain/jobs?job_id=test_id", - ), - ("", "test_id", ""), - (None, "test_id", ""), -] - - -@pytest.mark.parametrize( - "lookout_url_template, job_id, expected_url", get_lookout_url_test_cases -) -def test_get_lookout_url(lookout_url_template, job_id, expected_url): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template=lookout_url_template, - ) - - assert operator._get_lookout_url(job_id) == expected_url - - -def test_deepcopy_operator(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def test_deepcopy_operator_with_grpc_credentials_callback(): - armada_channel_args = { - "target": "127.0.0.1:50051", - "credentials_callback_args": { - "module_name": "tests.unit.test_armada_operator", - "function_name": "__example_test_callback", - "function_kwargs": { - "test_arg": "fake_arg", - }, - }, - } - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def __example_test_callback(foo=None): - return f"fake_cred {foo}" - - -def test_credentials_callback(): - callback = CredentialsCallback( - module_name="test_armada_operator", - function_name="__example_test_callback", - function_kwargs={"foo": "bar"}, - ) - - result = callback.call() - assert result == "fake_cred bar" diff --git a/third_party/airflow/tests/unit/test_armada_operator.py b/third_party/airflow/tests/unit/test_armada_operator.py deleted file mode 100644 index 571d634dc70..00000000000 --- a/third_party/airflow/tests/unit/test_armada_operator.py +++ /dev/null @@ -1,197 +0,0 @@ -import copy -from unittest.mock import patch, Mock - -import grpc -import pytest - -from armada.jobservice import jobservice_pb2 -from armada.operators.armada import ArmadaOperator -from armada.operators.grpc import CredentialsCallback -from armada.operators.utils import JobState - -get_lookout_url_test_cases = [ - ( - "http://localhost:8089/jobs?job_id=", - "test_id", - "http://localhost:8089/jobs?job_id=test_id", - ), - ( - 
"https://lookout.armada.domain/jobs?job_id=", - "test_id", - "https://lookout.armada.domain/jobs?job_id=test_id", - ), - ("", "test_id", ""), - (None, "test_id", ""), -] - - -@pytest.mark.parametrize( - "lookout_url_template, job_id, expected_url", get_lookout_url_test_cases -) -def test_get_lookout_url(lookout_url_template, job_id, expected_url): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template=lookout_url_template, - ) - - assert operator._get_lookout_url(job_id) == expected_url - - -def test_deepcopy_operator(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -@pytest.mark.skip("demonstrates how the old way of passing in credentials fails") -def test_deepcopy_operator_with_grpc_credentials(): - armada_channel_args = { - "target": "127.0.0.1:50051", - "credentials": grpc.composite_channel_credentials( - grpc.ssl_channel_credentials(), - grpc.metadata_call_credentials(("authorization", "fake_jwt")), - ), - } - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def test_deepcopy_operator_with_grpc_credentials_callback(): - armada_channel_args = { - "target": "127.0.0.1:50051", - "credentials_callback_args": { - "module_name": "tests.unit.test_armada_operator", - "function_name": "__example_test_callback", - "function_kwargs": { - "test_arg": "fake_arg", - }, - }, - } - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def __example_test_callback(foo=None): - return f"fake_cred {foo}" - - -def test_credentials_callback(): - callback = CredentialsCallback( - module_name="test_armada_operator", - function_name="__example_test_callback", - function_kwargs={"foo": "bar"}, - ) - - result = callback.call() - assert result == "fake_cred bar" - - -@patch("armada.operators.armada.search_for_job_complete") -@patch("armada.operators.armada.ArmadaClient", autospec=True) -@patch("armada.operators.armada.JobServiceClient", autospec=True) -def test_armada_operator_execute( - JobServiceClientMock, ArmadaClientMock, search_for_job_complete_mock -): - jsclient_mock = Mock() - jsclient_mock.health.return_value = 
jobservice_pb2.HealthCheckResponse( - status=jobservice_pb2.HealthCheckResponse.SERVING - ) - - JobServiceClientMock.return_value = jsclient_mock - - item = Mock() - item.job_id = "fake_id" - - job = Mock() - job.job_response_items = [ - item, - ] - - aclient_mock = Mock() - aclient_mock.submit_jobs.return_value = job - ArmadaClientMock.return_value = aclient_mock - - search_for_job_complete_mock.return_value = (JobState.SUCCEEDED, "No error") - - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="https://lookout.armada.domain/jobs?job_id=", - ) - - task_instance = Mock() - task_instance.task_id = "mock_task_id" - - dag = Mock() - dag.dag_id = "mock_dag_id" - - context = { - "run_id": "mock_run_id", - "ti": task_instance, - "dag": dag, - } - - try: - operator.execute(context) - except Exception as e: - assert False, f"{e}" - - jsclient_mock.health.assert_called() - aclient_mock.submit_jobs.assert_called() diff --git a/third_party/airflow/tests/unit/test_grpc.py b/third_party/airflow/tests/unit/test_grpc.py deleted file mode 100644 index 1e12b566067..00000000000 --- a/third_party/airflow/tests/unit/test_grpc.py +++ /dev/null @@ -1,26 +0,0 @@ -import armada.operators.grpc - - -def test_serialize_grpc_channel(): - src_chan_args = { - "target": "localhost:443", - "credentials_callback_args": { - "module_name": "channel_test", - "function_name": "get_credentials", - "function_kwargs": { - "example_arg": "test", - }, - }, - } - - source = armada.operators.grpc.GrpcChannelArguments(**src_chan_args) - - serialized = source.serialize() - assert serialized["target"] == src_chan_args["target"] - assert ( - serialized["credentials_callback_args"] - == src_chan_args["credentials_callback_args"] - ) - - reconstituted = armada.operators.grpc.GrpcChannelArguments(**serialized) - assert reconstituted == source diff --git a/third_party/airflow/tests/unit/test_search_for_job_complete.py b/third_party/airflow/tests/unit/test_search_for_job_complete.py deleted file mode 100644 index 279d5f08e17..00000000000 --- a/third_party/airflow/tests/unit/test_search_for_job_complete.py +++ /dev/null @@ -1,75 +0,0 @@ -from armada.operators.utils import JobState, search_for_job_complete -from armada.jobservice import jobservice_pb2 - - -def test_failed_event(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.FAILED, error="Testing Failure" - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - ) - assert job_complete[0] == JobState.FAILED - assert ( - job_complete[1] == "Armada test:id failed\nfailed with reason Testing Failure" - ) - - -def test_successful_event(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.SUCCEEDED - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - ) - assert job_complete[0] == JobState.SUCCEEDED - assert job_complete[1] == "Armada test:id 
succeeded" - - -def test_cancelled_event(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.CANCELLED - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - ) - assert job_complete[0] == JobState.CANCELLED - assert job_complete[1] == "Armada test:id cancelled" - - -def test_job_id_not_found(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.JOB_ID_NOT_FOUND - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - time_out_for_failure=5, - ) - assert job_complete[0] == JobState.JOB_ID_NOT_FOUND - assert ( - job_complete[1] == "Armada test:id could not find a job id and\nhit a timeout" - ) diff --git a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py deleted file mode 100644 index a842fa994d3..00000000000 --- a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py +++ /dev/null @@ -1,152 +0,0 @@ -from concurrent import futures -import logging - -import grpc -import pytest -import pytest_asyncio - -from job_service_mock import JobService, JobServiceOccasionalError - -from armada.operators.jobservice_asyncio import JobServiceAsyncIOClient -from armada.operators.jobservice import default_jobservice_channel_options -from armada.operators.utils import JobState, search_for_job_complete_async -from armada.jobservice import jobservice_pb2_grpc, jobservice_pb2 - - -@pytest.fixture -def server_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - jobservice_pb2_grpc.add_JobServiceServicer_to_server(JobService(), server) - server.add_insecure_port("[::]:50100") - server.start() - yield - server.stop(False) - - -@pytest_asyncio.fixture(scope="function") -async def js_aio_client(server_mock): - channel = grpc.aio.insecure_channel( - target="127.0.0.1:50100", - options={ - "grpc.keepalive_time_ms": 30000, - }.items(), - ) - await channel.channel_ready() - assert channel.get_state(True) == grpc.ChannelConnectivity.READY - - return JobServiceAsyncIOClient(channel) - - -@pytest.fixture -def server_occasional_error_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - jobservice_pb2_grpc.add_JobServiceServicer_to_server( - JobServiceOccasionalError(), server - ) - server.add_insecure_port("[::]:50101") - server.start() - yield - server.stop(False) - - -@pytest_asyncio.fixture(scope="function") -async def js_aio_retry_client(server_occasional_error_mock): - channel = grpc.aio.insecure_channel( - target="127.0.0.1:50101", - options=default_jobservice_channel_options, - ) - await channel.channel_ready() - assert channel.get_state(True) == grpc.ChannelConnectivity.READY - - return JobServiceAsyncIOClient(channel) - - -@pytest.mark.asyncio -async def test_failed_event(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_failed", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.FAILED - assert ( - job_complete[1] - 
== "Armada test:test_failed failed\nfailed with reason Test Error" - ) - - -@pytest.mark.asyncio -async def test_successful_event(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_succeeded", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.SUCCEEDED - assert job_complete[1] == "Armada test:test_succeeded succeeded" - - -@pytest.mark.asyncio -async def test_cancelled_event(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_cancelled", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.CANCELLED - assert job_complete[1] == "Armada test:test_cancelled cancelled" - - -@pytest.mark.asyncio -async def test_job_id_not_found(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - time_out_for_failure=5, - job_service_client=js_aio_client, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.JOB_ID_NOT_FOUND - assert ( - job_complete[1] == "Armada test:id could not find a job id and\nhit a timeout" - ) - - -@pytest.mark.asyncio -async def test_healthy(js_aio_client): - health = await js_aio_client.health() - assert health.status == jobservice_pb2.HealthCheckResponse.SERVING - - -@pytest.mark.asyncio -async def test_error_retry(js_aio_retry_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_succeeded", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_retry_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.SUCCEEDED - assert job_complete[1] == "Armada test:test_succeeded succeeded" diff --git a/third_party/airflow/tox.ini b/third_party/airflow/tox.ini index abfb8db10a5..09dd8ce15ea 100644 --- a/third_party/airflow/tox.ini +++ b/third_party/airflow/tox.ini @@ -13,7 +13,7 @@ allowlist_externals = find xargs commands = - coverage run -m pytest tests/unit/ + coverage run -m unittest discover coverage xml # This executes the dag files in examples but really only checks for imports and python errors bash -c "find examples/ -maxdepth 1 -type f -name *.py | xargs python3" @@ -21,18 +21,18 @@ commands = [testenv:format] extras = format commands = - black --check armada/operators tests/ examples/ -# Disabled until mypy reaches v1.0 -# mypy --ignore-missing-imports armada/operators tests/ examples/ - flake8 armada/operators tests/ examples/ + black armada/ test/ examples/ +# Disabled until mypy reaches v1.0 +# mypy --ignore-missing-imports armada/operators test/ examples/ + flake8 armada/ test/ examples/ -[testenv:format-code] +[testenv:format-check] extras = format commands = - black armada/operators tests/ examples/ -# Disabled until mypy reaches v1.0 -# mypy --ignore-missing-imports armada/operators tests/ examples/ - flake8 armada/operators tests/ examples/ + black --check armada/ test/ examples/ +# Disabled until mypy reaches v1.0 +# mypy --ignore-missing-imports armada/operators test/ examples/ + flake8 armada/ test/ examples/ [testenv:docs] basepython = python3.10