Skip to content

Commit

Permalink
Fix ssl implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
yrivardmulrooney committed May 19, 2022
1 parent f5a0b08 commit 341b4af
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 22 deletions.
21 changes: 18 additions & 3 deletions include/hareflow/detail/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <chrono>
#include <deque>
#include <memory>
#include <optional>

#ifdef _WIN32
# pragma warning(push)
Expand Down Expand Up @@ -36,8 +37,22 @@ class Connection : public std::enable_shared_from_this<Connection>
std::future<BinaryBuffer> read();

private:
using tcp_socket = boost::asio::ip::tcp::socket;
using ssl_stream = boost::asio::ssl::stream<tcp_socket&>;
using tcp_socket = boost::asio::ip::tcp::socket;
using ssl_context = boost::asio::ssl::context;
using ssl_stream = boost::asio::ssl::stream<tcp_socket&>;

class SslAdapter
{
public:
SslAdapter(tcp_socket& wrapped_socket, const std::string& host, bool verify_host);

ssl_stream& stream();

private:
ssl_context m_context;
std::unique_ptr<ssl_stream> m_stream;
};

class OutboxEntry
{
public:
Expand Down Expand Up @@ -84,7 +99,7 @@ class Connection : public std::enable_shared_from_this<Connection>
boost::asio::io_context& m_io_context;
boost::asio::io_context::strand m_strand;
tcp_socket m_socket;
std::unique_ptr<ssl_stream> m_ssl_stream;
std::optional<SslAdapter> m_ssl;
std::string m_host;
std::uint16_t m_port;
bool m_connection_failed;
Expand Down
43 changes: 26 additions & 17 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ void Connection::connect()
io::connect(m_socket, endpoints);
m_socket.set_option(tcp::no_delay(true));
m_socket.set_option(tcp_socket::keep_alive(true));
if (m_ssl_stream) {
m_ssl_stream->handshake(ssl_stream::client);
if (m_ssl) {
m_ssl->stream().handshake(ssl_stream::client);
}
m_connection_failed = false;
} catch (const boost::system::system_error& e) {
Expand Down Expand Up @@ -135,7 +135,7 @@ Connection::Connection(std::string host, std::uint16_t port, bool use_ssl, bool
: m_io_context(IoContextHolder::get()),
m_strand(m_io_context),
m_socket(m_io_context),
m_ssl_stream(),
m_ssl(),
m_host(std::move(host)),
m_port(port),
m_connection_failed(true),
Expand All @@ -152,14 +152,7 @@ Connection::Connection(std::string host, std::uint16_t port, bool use_ssl, bool
m_buffer_pool()
{
if (use_ssl) {
ssl::context ssl_context(ssl::context::tls);
ssl_context.set_options(ssl::context::default_workarounds | ssl::context::no_tlsv1 | ssl::context::no_tlsv1_1);
if (verify_host) {
ssl_context.set_default_verify_paths();
ssl_context.set_verify_mode(ssl::verify_peer);
ssl_context.set_verify_callback(ssl::rfc2818_verification(m_host));
}
m_ssl_stream = std::make_unique<ssl_stream>(m_socket, ssl_context);
m_ssl = SslAdapter(m_socket, m_host, verify_host);
}
}

Expand Down Expand Up @@ -190,8 +183,8 @@ void Connection::send_outbox()
buffers.emplace_back(entry->asio_buffer());
}

if (m_ssl_stream) {
io::async_write(*m_ssl_stream, buffers, io::bind_executor(m_strand, callback));
if (m_ssl) {
io::async_write(m_ssl->stream(), buffers, io::bind_executor(m_strand, callback));
} else {
io::async_write(m_socket, buffers, io::bind_executor(m_strand, callback));
}
Expand Down Expand Up @@ -237,15 +230,15 @@ void Connection::start_async_read_chain()
start_async_read_chain();
};

if (m_ssl_stream) {
io::async_read(*m_ssl_stream, buffer->as_asio_buffer() + sizeof(std::uint32_t), io::bind_executor(m_strand, post_to_inbox));
if (m_ssl) {
io::async_read(m_ssl->stream(), buffer->as_asio_buffer() + sizeof(std::uint32_t), io::bind_executor(m_strand, post_to_inbox));
} else {
io::async_read(m_socket, buffer->as_asio_buffer() + sizeof(std::uint32_t), io::bind_executor(m_strand, post_to_inbox));
}
};

if (m_ssl_stream) {
io::async_read(*m_ssl_stream, buffer->as_asio_buffer(), io::bind_executor(m_strand, callback));
if (m_ssl) {
io::async_read(m_ssl->stream(), buffer->as_asio_buffer(), io::bind_executor(m_strand, callback));
} else {
io::async_read(m_socket, buffer->as_asio_buffer(), io::bind_executor(m_strand, callback));
}
Expand Down Expand Up @@ -277,4 +270,20 @@ void Connection::start_reader_idle_monitoring()
}));
}

Connection::SslAdapter::SslAdapter(tcp_socket& wrapped_socket, const std::string& host, bool verify_host) : m_context(ssl::context::tls_client), m_stream()
{
m_context.set_options(ssl_context::default_workarounds | ssl_context::no_tlsv1 | ssl_context::no_tlsv1_1);
if (verify_host) {
m_context.set_default_verify_paths();
m_context.set_verify_mode(ssl::verify_peer);
m_context.set_verify_callback(ssl::rfc2818_verification(host));
}
m_stream = std::make_unique<ssl_stream>(wrapped_socket, m_context);
}

Connection::ssl_stream& Connection::SslAdapter::stream()
{
return *m_stream;
}

} // namespace hareflow::detail
9 changes: 7 additions & 2 deletions src/consumer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,13 @@ void ConsumerImpl::internal_stop()
void ConsumerImpl::internal_store_offset(std::uint64_t offset)
{
if (!m_name.empty() && m_client != nullptr) {
m_client->store_offset(m_name, m_stream, offset);
m_messages_since_last_persist = 0;
try {
m_client->store_offset(m_name, m_stream, offset);
m_messages_since_last_persist = 0;
} catch (const StreamException& e) {
Logger::warn("Failed to persist cursor: {}", e.what());
}

if (m_persist_cursor_task.valid()) {
m_persist_cursor_task.schedule_after(m_auto_cursor_config->force_persist_delay());
}
Expand Down

0 comments on commit 341b4af

Please sign in to comment.