diff --git a/include/hareflow/detail/connection.h b/include/hareflow/detail/connection.h index cf5ea2e..57ea53d 100644 --- a/include/hareflow/detail/connection.h +++ b/include/hareflow/detail/connection.h @@ -3,6 +3,7 @@ #include #include #include +#include #ifdef _WIN32 # pragma warning(push) @@ -36,8 +37,22 @@ class Connection : public std::enable_shared_from_this std::future read(); private: - using tcp_socket = boost::asio::ip::tcp::socket; - using ssl_stream = boost::asio::ssl::stream; + using tcp_socket = boost::asio::ip::tcp::socket; + using ssl_context = boost::asio::ssl::context; + using ssl_stream = boost::asio::ssl::stream; + + class SslAdapter + { + public: + SslAdapter(tcp_socket& wrapped_socket, const std::string& host, bool verify_host); + + ssl_stream& stream(); + + private: + ssl_context m_context; + std::unique_ptr m_stream; + }; + class OutboxEntry { public: @@ -84,7 +99,7 @@ class Connection : public std::enable_shared_from_this boost::asio::io_context& m_io_context; boost::asio::io_context::strand m_strand; tcp_socket m_socket; - std::unique_ptr m_ssl_stream; + std::optional m_ssl; std::string m_host; std::uint16_t m_port; bool m_connection_failed; diff --git a/src/connection.cpp b/src/connection.cpp index 71ca63d..da2c556 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -29,8 +29,8 @@ void Connection::connect() io::connect(m_socket, endpoints); m_socket.set_option(tcp::no_delay(true)); m_socket.set_option(tcp_socket::keep_alive(true)); - if (m_ssl_stream) { - m_ssl_stream->handshake(ssl_stream::client); + if (m_ssl) { + m_ssl->stream().handshake(ssl_stream::client); } m_connection_failed = false; } catch (const boost::system::system_error& e) { @@ -135,7 +135,7 @@ Connection::Connection(std::string host, std::uint16_t port, bool use_ssl, bool : m_io_context(IoContextHolder::get()), m_strand(m_io_context), m_socket(m_io_context), - m_ssl_stream(), + m_ssl(), m_host(std::move(host)), m_port(port), m_connection_failed(true), @@ -152,14 +152,7 @@ Connection::Connection(std::string host, std::uint16_t port, bool use_ssl, bool m_buffer_pool() { if (use_ssl) { - ssl::context ssl_context(ssl::context::tls); - ssl_context.set_options(ssl::context::default_workarounds | ssl::context::no_tlsv1 | ssl::context::no_tlsv1_1); - if (verify_host) { - ssl_context.set_default_verify_paths(); - ssl_context.set_verify_mode(ssl::verify_peer); - ssl_context.set_verify_callback(ssl::rfc2818_verification(m_host)); - } - m_ssl_stream = std::make_unique(m_socket, ssl_context); + m_ssl = SslAdapter(m_socket, m_host, verify_host); } } @@ -190,8 +183,8 @@ void Connection::send_outbox() buffers.emplace_back(entry->asio_buffer()); } - if (m_ssl_stream) { - io::async_write(*m_ssl_stream, buffers, io::bind_executor(m_strand, callback)); + if (m_ssl) { + io::async_write(m_ssl->stream(), buffers, io::bind_executor(m_strand, callback)); } else { io::async_write(m_socket, buffers, io::bind_executor(m_strand, callback)); } @@ -237,15 +230,15 @@ void Connection::start_async_read_chain() start_async_read_chain(); }; - if (m_ssl_stream) { - io::async_read(*m_ssl_stream, buffer->as_asio_buffer() + sizeof(std::uint32_t), io::bind_executor(m_strand, post_to_inbox)); + if (m_ssl) { + io::async_read(m_ssl->stream(), buffer->as_asio_buffer() + sizeof(std::uint32_t), io::bind_executor(m_strand, post_to_inbox)); } else { io::async_read(m_socket, buffer->as_asio_buffer() + sizeof(std::uint32_t), io::bind_executor(m_strand, post_to_inbox)); } }; - if (m_ssl_stream) { - io::async_read(*m_ssl_stream, buffer->as_asio_buffer(), io::bind_executor(m_strand, callback)); + if (m_ssl) { + io::async_read(m_ssl->stream(), buffer->as_asio_buffer(), io::bind_executor(m_strand, callback)); } else { io::async_read(m_socket, buffer->as_asio_buffer(), io::bind_executor(m_strand, callback)); } @@ -277,4 +270,20 @@ void Connection::start_reader_idle_monitoring() })); } +Connection::SslAdapter::SslAdapter(tcp_socket& wrapped_socket, const std::string& host, bool verify_host) : m_context(ssl::context::tls_client), m_stream() +{ + m_context.set_options(ssl_context::default_workarounds | ssl_context::no_tlsv1 | ssl_context::no_tlsv1_1); + if (verify_host) { + m_context.set_default_verify_paths(); + m_context.set_verify_mode(ssl::verify_peer); + m_context.set_verify_callback(ssl::rfc2818_verification(host)); + } + m_stream = std::make_unique(wrapped_socket, m_context); +} + +Connection::ssl_stream& Connection::SslAdapter::stream() +{ + return *m_stream; +} + } // namespace hareflow::detail \ No newline at end of file diff --git a/src/consumer_impl.cpp b/src/consumer_impl.cpp index ad2e324..aa4a064 100644 --- a/src/consumer_impl.cpp +++ b/src/consumer_impl.cpp @@ -186,8 +186,13 @@ void ConsumerImpl::internal_stop() void ConsumerImpl::internal_store_offset(std::uint64_t offset) { if (!m_name.empty() && m_client != nullptr) { - m_client->store_offset(m_name, m_stream, offset); - m_messages_since_last_persist = 0; + try { + m_client->store_offset(m_name, m_stream, offset); + m_messages_since_last_persist = 0; + } catch (const StreamException& e) { + Logger::warn("Failed to persist cursor: {}", e.what()); + } + if (m_persist_cursor_task.valid()) { m_persist_cursor_task.schedule_after(m_auto_cursor_config->force_persist_delay()); }