diff --git a/src/ejabberd_sql.erl b/src/ejabberd_sql.erl index 5e5549c45b8..a803b41394d 100644 --- a/src/ejabberd_sql.erl +++ b/src/ejabberd_sql.erl @@ -32,9 +32,12 @@ %% External exports -export([start_link/2, sql_query/2, + sql_query/3, sql_query_t/1, sql_transaction/2, + sql_transaction/3, sql_bloc/2, + sql_bloc/3, abort/1, restart/1, use_new_schema/0, @@ -84,7 +87,8 @@ reconnect_count = 0 :: non_neg_integer(), host :: binary(), pending_requests :: p1_queue:queue(), - overload_reported :: undefined | integer()}). + overload_reported :: undefined | integer(), + timeout :: pos_integer()}). -define(STATE_KEY, ejabberd_sql_state). -define(NESTING_KEY, ejabberd_sql_nesting_level). @@ -120,35 +124,48 @@ start_link(Host, I) -> p1_fsm:start_link({local, Proc}, ?MODULE, [Host], fsm_limit_opts() ++ ?FSMOPTS). +-spec sql_query(binary(), sql_query(T), pos_integer()) -> sql_query_result(T). +sql_query(Host, Query, Timeout) -> + sql_call(Host, {sql_query, Query}, Timeout). + -spec sql_query(binary(), sql_query(T)) -> sql_query_result(T). sql_query(Host, Query) -> - sql_call(Host, {sql_query, Query}). + sql_query(Host, Query, query_timeout(Host)). %% SQL transaction based on a list of queries %% This function automatically --spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T)) -> +-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T), pos_integer()) -> {atomic, T} | {aborted, any()}. -sql_transaction(Host, Queries) +sql_transaction(Host, Queries, Timeout) when is_list(Queries) -> F = fun () -> lists:foreach(fun (Query) -> sql_query_t(Query) end, Queries) end, - sql_transaction(Host, F); + sql_transaction(Host, F, Timeout); %% SQL transaction, based on a erlang anonymous function (F = fun) -sql_transaction(Host, F) when is_function(F) -> - case sql_call(Host, {sql_transaction, F}) of +sql_transaction(Host, F, Timeout) when is_function(F) -> + case sql_call(Host, {sql_transaction, F}, Timeout) of {atomic, _} = Ret -> Ret; {aborted, _} = Ret -> Ret; Err -> {aborted, Err} end. +-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T)) -> + {atomic, T} | + {aborted, any()}. +sql_transaction(Host, Queries) -> + sql_transaction(Host, Queries, query_timeout(Host)). + %% SQL bloc, based on a erlang anonymous function (F = fun) -sql_bloc(Host, F) -> sql_call(Host, {sql_bloc, F}). +sql_bloc(Host, F, Timeout) -> + sql_call(Host, {sql_bloc, F}, Timeout). -sql_call(Host, Msg) -> - Timeout = query_timeout(Host), +sql_bloc(Host, F) -> + sql_bloc(Host, F, query_timeout(Host)). + +sql_call(Host, Msg, Timeout) -> case get(?STATE_KEY) of undefined -> sync_send_event(Host, @@ -355,7 +372,8 @@ init([Host]) -> QueueType = ejabberd_option:sql_queue_type(Host), {ok, connecting, #state{db_type = DBType, host = Host, - pending_requests = p1_queue:new(QueueType, max_fsm_queue())}}. + pending_requests = p1_queue:new(QueueType, max_fsm_queue()), + timeout = query_timeout(Host)}}. connecting(connect, #state{host = Host} = State) -> ConnectRes = case db_opts(Host) of @@ -496,10 +514,12 @@ handle_reconnect(Reason, #state{host = Host, reconnect_count = RC} = State) -> _ -> ok end, p1_fsm:send_event_after(StartInterval, connect), - {next_state, connecting, State#state{reconnect_count = RC + 1}}. + {next_state, connecting, State#state{reconnect_count = RC + 1, + timeout = query_timeout(Host)}}. run_sql_cmd(Command, From, State, Timestamp) -> - case current_time() >= Timestamp of + CT = current_time(), + case CT >= Timestamp of true -> State1 = report_overload(State), {next_state, session_established, State1}; @@ -510,8 +530,9 @@ run_sql_cmd(Command, From, State, Timestamp) -> State#state.pending_requests), handle_reconnect(Reason, State#state{pending_requests = PR}) after 0 -> + Timeout = min(query_timeout(State#state.host), Timestamp - CT), put(?NESTING_KEY, ?TOP_LEVEL_TXN), - put(?STATE_KEY, State), + put(?STATE_KEY, State#state{timeout = Timeout}), abort_on_driver_error(outer_op(Command), From, Timestamp) end end. @@ -726,7 +747,7 @@ sql_query_internal(F) when is_function(F) -> sql_query_internal(Query) -> State = get(?STATE_KEY), ?DEBUG("SQL: \"~ts\"", [Query]), - QueryTimeout = query_timeout(State#state.host), + QueryTimeout = State#state.timeout, Res = case State#state.db_type of odbc -> to_odbc(odbc:sql_query(State#state.db_ref, [Query],