Skip to content

Commit

Permalink
Add ability to specify custom timeout for sql operations
Browse files Browse the repository at this point in the history
  • Loading branch information
prefiks committed Jul 4, 2024
1 parent b978a47 commit 25b78b7
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions src/ejabberd_sql.erl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@
%% External exports
-export([start_link/2,
sql_query/2,
sql_query/3,
sql_query_t/1,
sql_transaction/2,
sql_transaction/3,
sql_bloc/2,
sql_bloc/3,
abort/1,
restart/1,
use_new_schema/0,
Expand Down Expand Up @@ -84,7 +87,8 @@
reconnect_count = 0 :: non_neg_integer(),
host :: binary(),
pending_requests :: p1_queue:queue(),
overload_reported :: undefined | integer()}).
overload_reported :: undefined | integer(),
timeout :: pos_integer()}).

-define(STATE_KEY, ejabberd_sql_state).
-define(NESTING_KEY, ejabberd_sql_nesting_level).
Expand Down Expand Up @@ -120,35 +124,48 @@ start_link(Host, I) ->
p1_fsm:start_link({local, Proc}, ?MODULE, [Host],
fsm_limit_opts() ++ ?FSMOPTS).

-spec sql_query(binary(), sql_query(T), pos_integer()) -> sql_query_result(T).
sql_query(Host, Query, Timeout) ->
sql_call(Host, {sql_query, Query}, Timeout).

-spec sql_query(binary(), sql_query(T)) -> sql_query_result(T).
sql_query(Host, Query) ->
sql_call(Host, {sql_query, Query}).
sql_query(Host, Query, query_timeout(Host)).

%% SQL transaction based on a list of queries
%% This function automatically
-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T)) ->
-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T), pos_integer()) ->
{atomic, T} |
{aborted, any()}.
sql_transaction(Host, Queries)
sql_transaction(Host, Queries, Timeout)
when is_list(Queries) ->
F = fun () ->
lists:foreach(fun (Query) -> sql_query_t(Query) end,
Queries)
end,
sql_transaction(Host, F);
sql_transaction(Host, F, Timeout);
%% SQL transaction, based on a erlang anonymous function (F = fun)
sql_transaction(Host, F) when is_function(F) ->
case sql_call(Host, {sql_transaction, F}) of
sql_transaction(Host, F, Timeout) when is_function(F) ->
case sql_call(Host, {sql_transaction, F}, Timeout) of
{atomic, _} = Ret -> Ret;
{aborted, _} = Ret -> Ret;
Err -> {aborted, Err}
end.

-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T)) ->
{atomic, T} |
{aborted, any()}.
sql_transaction(Host, Queries) ->
sql_transaction(Host, Queries, query_timeout(Host)).

%% SQL bloc, based on a erlang anonymous function (F = fun)
sql_bloc(Host, F) -> sql_call(Host, {sql_bloc, F}).
sql_bloc(Host, F, Timeout) ->
sql_call(Host, {sql_bloc, F}, Timeout).

sql_call(Host, Msg) ->
Timeout = query_timeout(Host),
sql_bloc(Host, F) ->
sql_bloc(Host, F, query_timeout(Host)).

sql_call(Host, Msg, Timeout) ->
case get(?STATE_KEY) of
undefined ->
sync_send_event(Host,
Expand Down Expand Up @@ -355,7 +372,8 @@ init([Host]) ->
QueueType = ejabberd_option:sql_queue_type(Host),
{ok, connecting,
#state{db_type = DBType, host = Host,
pending_requests = p1_queue:new(QueueType, max_fsm_queue())}}.
pending_requests = p1_queue:new(QueueType, max_fsm_queue()),
timeout = query_timeout(Host)}}.

connecting(connect, #state{host = Host} = State) ->
ConnectRes = case db_opts(Host) of
Expand Down Expand Up @@ -496,10 +514,12 @@ handle_reconnect(Reason, #state{host = Host, reconnect_count = RC} = State) ->
_ -> ok
end,
p1_fsm:send_event_after(StartInterval, connect),
{next_state, connecting, State#state{reconnect_count = RC + 1}}.
{next_state, connecting, State#state{reconnect_count = RC + 1,
timeout = query_timeout(Host)}}.

run_sql_cmd(Command, From, State, Timestamp) ->
case current_time() >= Timestamp of
CT = current_time(),
case CT >= Timestamp of
true ->
State1 = report_overload(State),
{next_state, session_established, State1};
Expand All @@ -510,8 +530,9 @@ run_sql_cmd(Command, From, State, Timestamp) ->
State#state.pending_requests),
handle_reconnect(Reason, State#state{pending_requests = PR})
after 0 ->
Timeout = min(query_timeout(State#state.host), Timestamp - CT),
put(?NESTING_KEY, ?TOP_LEVEL_TXN),
put(?STATE_KEY, State),
put(?STATE_KEY, State#state{timeout = Timeout}),
abort_on_driver_error(outer_op(Command), From, Timestamp)
end
end.
Expand Down Expand Up @@ -726,7 +747,7 @@ sql_query_internal(F) when is_function(F) ->
sql_query_internal(Query) ->
State = get(?STATE_KEY),
?DEBUG("SQL: \"~ts\"", [Query]),
QueryTimeout = query_timeout(State#state.host),
QueryTimeout = State#state.timeout,
Res = case State#state.db_type of
odbc ->
to_odbc(odbc:sql_query(State#state.db_ref, [Query],
Expand Down

0 comments on commit 25b78b7

Please sign in to comment.