diff --git a/Cargo.lock b/Cargo.lock
index b1724a995b2f..6ac34777d3e6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3609,6 +3609,7 @@ dependencies = [
  "match-template",
  "rand 0.8.5",
  "serde",
+ "tokio",
  "typetag",
 ]

diff --git a/src/query/expression/src/utils/udf_client.rs b/src/query/expression/src/utils/udf_client.rs
index 515f548a41da..9a434df0ce3c 100644
--- a/src/query/expression/src/utils/udf_client.rs
+++ b/src/query/expression/src/utils/udf_client.rs
@@ -35,6 +35,8 @@ use crate::DataSchema;
 const UDF_TCP_KEEP_ALIVE_SEC: u64 = 30;
 const UDF_HTTP2_KEEP_ALIVE_INTERVAL_SEC: u64 = 60;
 const UDF_KEEP_ALIVE_TIMEOUT_SEC: u64 = 20;
+// max_decoding_message_size is 4MB by default; we use 16GB here.
+// max_encoding_message_size is usize::MAX by default.
 const MAX_DECODING_MESSAGE_SIZE: usize = 16 * 1024 * 1024 * 1024;

 #[derive(Debug, Clone)]
diff --git a/src/query/pipeline/transforms/Cargo.toml b/src/query/pipeline/transforms/Cargo.toml
index 1735c8a1d012..4a95b79513ed 100644
--- a/src/query/pipeline/transforms/Cargo.toml
+++ b/src/query/pipeline/transforms/Cargo.toml
@@ -22,6 +22,7 @@ async-trait = { workspace = true }
 jsonb = { workspace = true }
 match-template = { workspace = true }
 serde = { workspace = true }
+tokio = { workspace = true }
 typetag = { workspace = true }

 [dev-dependencies]
diff --git a/src/query/pipeline/transforms/src/processors/transforms/mod.rs b/src/query/pipeline/transforms/src/processors/transforms/mod.rs
index a9420663f2cf..bbaac634627c 100644
--- a/src/query/pipeline/transforms/src/processors/transforms/mod.rs
+++ b/src/query/pipeline/transforms/src/processors/transforms/mod.rs
@@ -25,6 +25,7 @@ mod transform_dummy;
 mod transform_multi_sort_merge;
 mod transform_sort_merge_base;

+mod transform_retry_async;
 mod transform_sort_merge;
 mod transform_sort_merge_limit;
 pub mod transform_sort_partial;
@@ -38,6 +39,7 @@ pub use transform_blocking::*;
 pub use transform_compact::*;
 pub use transform_dummy::*;
 pub use transform_multi_sort_merge::try_add_multi_sort_merge;
+pub use transform_retry_async::*;
 pub use transform_sort_merge::sort_merge;
 pub use transform_sort_merge::*;
 pub use transform_sort_merge_base::*;
diff --git a/src/query/pipeline/transforms/src/processors/transforms/transform_retry_async.rs b/src/query/pipeline/transforms/src/processors/transforms/transform_retry_async.rs
new file mode 100644
index 000000000000..6622c639e488
--- /dev/null
+++ b/src/query/pipeline/transforms/src/processors/transforms/transform_retry_async.rs
@@ -0,0 +1,74 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use databend_common_exception::Result;
+use databend_common_expression::DataBlock;
+
+use super::AsyncTransform;
+
+pub trait AsyncRetry: AsyncTransform {
+    fn retry_on(&self, err: &databend_common_exception::ErrorCode) -> bool;
+    fn retry_strategy(&self) -> RetryStrategy;
+}
+
+#[derive(Clone)]
+pub struct RetryStrategy {
+    pub retry_times: usize,
+    pub retry_sleep_duration: Option<tokio::time::Duration>,
+}
+
+pub struct AsyncRetryWrapper<T: AsyncRetry + 'static> {
+    t: T,
+}
+
+impl<T: AsyncRetry + 'static> AsyncRetryWrapper<T> {
+    pub fn create(inner: T) -> Self {
+        Self { t: inner }
+    }
+}
+
+#[async_trait::async_trait]
+impl<T: AsyncRetry + 'static> AsyncTransform for AsyncRetryWrapper<T> {
+    const NAME: &'static str = T::NAME;
+
+    async fn transform(&mut self, data: DataBlock) -> Result<DataBlock> {
+        let strategy = self.t.retry_strategy();
+        for _ in 0..strategy.retry_times {
+            match self.t.transform(data.clone()).await {
+                Ok(v) => return Ok(v),
+                Err(e) => {
+                    if !self.t.retry_on(&e) {
+                        return Err(e);
+                    }
+                    if let Some(duration) = strategy.retry_sleep_duration {
+                        tokio::time::sleep(duration).await;
+                    }
+                }
+            }
+        }
+        self.t.transform(data.clone()).await
+    }
+
+    fn name(&self) -> String {
+        Self::NAME.to_string()
+    }
+
+    async fn on_start(&mut self) -> Result<()> {
+        self.t.on_start().await
+    }
+
+    async fn on_finish(&mut self) -> Result<()> {
+        self.t.on_finish().await
+    }
+}
diff --git a/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs
index 473471c358af..79c55f40d3fb 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs
@@ -24,8 +24,11 @@ use databend_common_expression::DataBlock;
 use databend_common_expression::DataField;
 use databend_common_expression::DataSchema;
 use databend_common_expression::FunctionContext;
+use databend_common_pipeline_transforms::processors::AsyncRetry;
+use databend_common_pipeline_transforms::processors::AsyncRetryWrapper;
 use databend_common_pipeline_transforms::processors::AsyncTransform;
 use databend_common_pipeline_transforms::processors::AsyncTransformer;
+use databend_common_pipeline_transforms::processors::RetryStrategy;
 use databend_common_sql::executor::physical_plans::UdfFunctionDesc;

 use crate::pipelines::processors::InputPort;
@@ -44,10 +47,22 @@ impl TransformUdfServer {
         input: Arc<InputPort>,
         output: Arc<OutputPort>,
     ) -> Result<Box<dyn Processor>> {
-        Ok(AsyncTransformer::create(input, output, Self {
-            func_ctx,
-            funcs,
-        }))
+        let s = Self { func_ctx, funcs };
+        let retry_wrapper = AsyncRetryWrapper::create(s);
+        Ok(AsyncTransformer::create(input, output, retry_wrapper))
+    }
+}
+
+impl AsyncRetry for TransformUdfServer {
+    fn retry_on(&self, _err: &databend_common_exception::ErrorCode) -> bool {
+        true
+    }
+
+    fn retry_strategy(&self) -> RetryStrategy {
+        RetryStrategy {
+            retry_times: 64,
+            retry_sleep_duration: Some(tokio::time::Duration::from_millis(500)),
+        }
+    }
 }
diff --git a/tests/sqllogictests/suites/udf_server/udf_server_test.test b/tests/sqllogictests/suites/udf_server/udf_server_test.test
index 46757934a18e..8681280567ba 100644
--- a/tests/sqllogictests/suites/udf_server/udf_server_test.test
+++ b/tests/sqllogictests/suites/udf_server/udf_server_test.test
@@ -20,6 +20,9 @@ DROP FUNCTION IF EXISTS bool_select;
 statement ok
 DROP FUNCTION IF EXISTS gcd;

+statement ok
+DROP FUNCTION IF EXISTS gcd_error;
+
 statement ok
 DROP FUNCTION IF EXISTS decimal_div;

@@ -86,6 +89,9 @@ CREATE OR REPLACE FUNCTION bool_select (BOOLEAN, BIGINT, BIGINT) RETURNS BIGINT
 statement ok
 CREATE OR REPLACE FUNCTION gcd (INT, INT) RETURNS INT LANGUAGE python HANDLER = 'gcd' ADDRESS = 'http://0.0.0.0:8815';

+statement ok
+CREATE OR REPLACE FUNCTION gcd_error (INT, INT) RETURNS INT LANGUAGE python HANDLER = 'gcd_error' ADDRESS = 'http://0.0.0.0:8815';
+
 statement ok
 CREATE OR REPLACE FUNCTION split_and_join (VARCHAR, VARCHAR, VARCHAR) RETURNS VARCHAR LANGUAGE python HANDLER = 'split_and_join' ADDRESS = 'http://0.0.0.0:8815';

@@ -177,6 +183,10 @@ SELECT gcd(a,b) d from (select number + 1 a, a * 2 b from numbers(3) where numb
 2
 3

+query I
+SELECT sum(gcd_error(a,b)) from (select number + 1 a, a * 2 b from numbers(3000))
+----
+4501500

 statement ok
 create or replace table gcd_target(id int);
diff --git a/tests/udf/udf_server.py b/tests/udf/udf_server.py
index b0106d7bf777..9bd9b9d91950 100644
--- a/tests/udf/udf_server.py
+++ b/tests/udf/udf_server.py
@@ -56,6 +56,22 @@ def gcd(x: int, y: int) -> int:
         (x, y) = (y, x % y)
     return x

+gcd_error_cnt = 0
+@udf(
+    name="gcd_error",
+    input_types=["INT", "INT"],
+    result_type="INT",
+    skip_null=True,
+)
+def gcd_error(x: int, y: int) -> int:
+    global gcd_error_cnt
+    if y % 2 == 0 and gcd_error_cnt <= 3:
+        gcd_error_cnt += 1
+        raise ValueError("gcd_error")
+    while y != 0:
+        (x, y) = (y, x % y)
+    return x
+

 @udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR")
 def split_and_join(s: str, split_s: str, join_s: str) -> str:
@@ -310,6 +326,7 @@ def wait_concurrent(x):
     udf_server.add_function(binary_reverse)
     udf_server.add_function(bool_select)
     udf_server.add_function(gcd)
+    udf_server.add_function(gcd_error)
     udf_server.add_function(split_and_join)
     udf_server.add_function(decimal_div)
     udf_server.add_function(hex_to_dec)
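Note on the new expected value: gcd_error is always called with b = 2a, and gcd(n, 2n) = n, so the query sums 1 + 2 + ... + 3000 = 3000 * 3001 / 2 = 4501500. The Python handler deliberately raises ValueError for its first few invocations, so the query can only return this result because the new retry wrapper re-sends the failed UDF requests until they succeed.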
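For readers who want to opt another async transform into the retry machinery, here is a minimal sketch (not part of this patch); FlakyTransform, its retry budget, and the blanket retry_on are illustrative assumptions, and the sketch assumes AsyncTransform's other methods keep their default implementations. Note the wrapper's semantics: up to retry_times attempts inside its loop, then one final attempt whose result (or error) is returned to the caller.

use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::DataBlock;
use databend_common_pipeline_transforms::processors::AsyncRetry;
use databend_common_pipeline_transforms::processors::AsyncRetryWrapper;
use databend_common_pipeline_transforms::processors::AsyncTransform;
use databend_common_pipeline_transforms::processors::RetryStrategy;

// Hypothetical transform whose `transform` calls an external service that may fail transiently.
pub struct FlakyTransform;

#[async_trait::async_trait]
impl AsyncTransform for FlakyTransform {
    const NAME: &'static str = "FlakyTransform";

    async fn transform(&mut self, data: DataBlock) -> Result<DataBlock> {
        // ... call the external service and build the output block here ...
        Ok(data)
    }
}

impl AsyncRetry for FlakyTransform {
    // Retry every error, as TransformUdfServer does; a stricter transform could
    // inspect the error code and only retry transport-level failures.
    fn retry_on(&self, _err: &ErrorCode) -> bool {
        true
    }

    fn retry_strategy(&self) -> RetryStrategy {
        RetryStrategy {
            retry_times: 3,
            retry_sleep_duration: Some(tokio::time::Duration::from_millis(100)),
        }
    }
}

// Wrapping keeps the pipeline wiring unchanged, since the wrapper is itself an AsyncTransform.
pub fn wrapped() -> AsyncRetryWrapper<FlakyTransform> {
    AsyncRetryWrapper::create(FlakyTransform)
}

In the patch itself, TransformUdfServer::try_create follows the same pattern: it wraps itself in AsyncRetryWrapper::create before handing the result to AsyncTransformer::create.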