Skip to content

Commit

Permalink
client-side-WRR-LB: Improve Client Side Weighted Round Robin lb polic…
Browse files Browse the repository at this point in the history
…y. (#37127)

Signed-off-by: Misha Efimov <mef@google.com>
  • Loading branch information
efimki authored Nov 21, 2024
1 parent dd6b7b7 commit 150e16d
Show file tree
Hide file tree
Showing 9 changed files with 402 additions and 67 deletions.
5 changes: 5 additions & 0 deletions changelogs/current.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ bug_fixes:
- area: tracers
change: |
Avoid possible overflow when setting span attributes in Dynatrace sampler.
- area: load_balancing
change: |
Fixed default host weight calculation of :ref:`client_side_weighted_round_robin
<envoy_v3_api_msg_extensions.load_balancing_policies.client_side_weighted_round_robin.v3.ClientSideWeightedRoundRobin>`
to properly handle even number of valid host weights.
removed_config_or_runtime:
# *Normally occurs at the end of the* :ref:`deprecation period <deprecated>`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ envoy_cc_library(
srcs = ["client_side_weighted_round_robin_lb.cc"],
hdrs = ["client_side_weighted_round_robin_lb.h"],
deps = [
"//envoy/thread_local:thread_local_interface",
"//source/common/common:callback_impl_lib",
"//source/common/orca:orca_load_metrics_lib",
"//source/extensions/load_balancing_policies/common:load_balancer_lib",
"//source/extensions/load_balancing_policies/round_robin:round_robin_lb_lib",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ std::string getHostAddress(const Host* host) {
} // namespace

ClientSideWeightedRoundRobinLbConfig::ClientSideWeightedRoundRobinLbConfig(
const ClientSideWeightedRoundRobinLbProto& lb_proto, Event::Dispatcher& main_thread_dispatcher)
: main_thread_dispatcher_(main_thread_dispatcher) {
const ClientSideWeightedRoundRobinLbProto& lb_proto, Event::Dispatcher& main_thread_dispatcher,
ThreadLocal::SlotAllocator& tls_slot_allocator)
: main_thread_dispatcher_(main_thread_dispatcher), tls_slot_allocator_(tls_slot_allocator) {
ENVOY_LOG_MISC(trace, "ClientSideWeightedRoundRobinLbConfig config {}", lb_proto.DebugString());
metric_names_for_computing_utilization =
std::vector<std::string>(lb_proto.metric_names_for_computing_utilization().begin(),
Expand All @@ -50,12 +51,18 @@ ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb::WorkerLocalLb(
Runtime::Loader& runtime, Random::RandomGenerator& random,
const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config,
const ClientSideWeightedRoundRobinLbConfig& client_side_weighted_round_robin_config,
TimeSource& time_source)
TimeSource& time_source, OptRef<ThreadLocalShim> tls_shim)
: RoundRobinLoadBalancer(priority_set, local_priority_set, stats, runtime, random,
common_config,
/*round_robin_config=*/std::nullopt, time_source) {
orca_load_report_handler_ =
std::make_shared<OrcaLoadReportHandler>(client_side_weighted_round_robin_config, time_source);
if (tls_shim.has_value()) {
apply_weights_cb_handle_ = tls_shim->apply_weights_cb_helper_.add([this](uint32_t priority) {
refresh(priority);
return absl::OkStatus();
});
}
}

HostConstSharedPtr
Expand Down Expand Up @@ -107,13 +114,17 @@ void ClientSideWeightedRoundRobinLoadBalancer::startWeightUpdatesOnMainThread(
void ClientSideWeightedRoundRobinLoadBalancer::updateWeightsOnMainThread() {
ENVOY_LOG(trace, "updateWeightsOnMainThread");
for (const HostSetPtr& host_set : priority_set_.hostSetsPerPriority()) {
updateWeightsOnHosts(host_set->hosts());
if (updateWeightsOnHosts(host_set->hosts())) {
// If weights have changed, then apply them to all workers.
factory_->applyWeightsToAllWorkers(host_set->priority());
}
}
}

void ClientSideWeightedRoundRobinLoadBalancer::updateWeightsOnHosts(const HostVector& hosts) {
bool ClientSideWeightedRoundRobinLoadBalancer::updateWeightsOnHosts(const HostVector& hosts) {
std::vector<uint32_t> weights;
HostVector hosts_with_default_weight;
bool weights_updated = false;
const MonotonicTime now = time_source_.monotonicTime();
// Weight is considered invalid (too recent) if it was first updated within `blackout_period_`.
const MonotonicTime max_non_empty_since = now - blackout_period_;
Expand All @@ -132,28 +143,48 @@ void ClientSideWeightedRoundRobinLoadBalancer::updateWeightsOnHosts(const HostVe
// If `client_side_weight` is valid, then set it as the host weight and store it in
// `weights` to calculate median valid weight across all hosts.
if (client_side_weight.has_value()) {
weights.push_back(*client_side_weight);
host_ptr->weight(*client_side_weight);
ENVOY_LOG(trace, "updateWeights hostWeight {} = {}", getHostAddress(host_ptr.get()),
host_ptr->weight());
const uint32_t new_weight = client_side_weight.value();
weights.push_back(new_weight);
if (new_weight != host_ptr->weight()) {
host_ptr->weight(new_weight);
ENVOY_LOG(trace, "updateWeights hostWeight {} = {}", getHostAddress(host_ptr.get()),
host_ptr->weight());
weights_updated = true;
}
} else {
// If `client_side_weight` is invalid, then set host to default (median) weight.
hosts_with_default_weight.push_back(host_ptr);
}
}
// Calculate the default weight as median of all valid weights.
uint32_t default_weight = 1;
if (!weights.empty()) {
auto median_it = weights.begin() + weights.size() / 2;
std::nth_element(weights.begin(), median_it, weights.end());
default_weight = *median_it;
}
// Update the hosts with default weight.
for (const auto& host_ptr : hosts_with_default_weight) {
host_ptr->weight(default_weight);
ENVOY_LOG(trace, "updateWeights default hostWeight {} = {}", getHostAddress(host_ptr.get()),
host_ptr->weight());
// If some hosts don't have valid weight, then update them with default weight.
if (!hosts_with_default_weight.empty()) {
// Calculate the default weight as median of all valid weights.
uint32_t default_weight = 1;
if (!weights.empty()) {
const auto median_it = weights.begin() + weights.size() / 2;
std::nth_element(weights.begin(), median_it, weights.end());
if (weights.size() % 2 == 1) {
default_weight = *median_it;
} else {
// If the number of weights is even, then the median is the average of the two middle
// elements.
const auto lower_median_it = std::max_element(weights.begin(), median_it);
// Use uint64_t to avoid potential overflow of the weights sum.
default_weight = static_cast<uint32_t>(
(static_cast<uint64_t>(*lower_median_it) + static_cast<uint64_t>(*median_it)) / 2);
}
}
// Update the hosts with default weight.
for (const auto& host_ptr : hosts_with_default_weight) {
if (default_weight != host_ptr->weight()) {
host_ptr->weight(default_weight);
ENVOY_LOG(trace, "updateWeights default hostWeight {} = {}", getHostAddress(host_ptr.get()),
host_ptr->weight());
weights_updated = true;
}
}
}
return weights_updated;
}

void ClientSideWeightedRoundRobinLoadBalancer::addClientSideLbPolicyDataToHosts(
Expand Down Expand Up @@ -246,7 +277,16 @@ Upstream::LoadBalancerPtr ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalL
ASSERT(typed_lb_config != nullptr);
return std::make_unique<Upstream::ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb>(
params.priority_set, params.local_priority_set, cluster_info_.lbStats(), runtime_, random_,
cluster_info_.lbConfig(), *typed_lb_config, time_source_);
cluster_info_.lbConfig(), *typed_lb_config, time_source_, tls_->get());
}

void ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLbFactory::applyWeightsToAllWorkers(
uint32_t priority) {
tls_->runOnAllThreads([priority](OptRef<ThreadLocalShim> tls_shim) -> void {
if (tls_shim.has_value()) {
auto status = tls_shim->apply_weights_cb_helper_.runCallbacks(priority);
}
});
}

ClientSideWeightedRoundRobinLoadBalancer::ClientSideWeightedRoundRobinLoadBalancer(
Expand All @@ -265,8 +305,9 @@ absl::Status ClientSideWeightedRoundRobinLoadBalancer::initialize() {
}
// Setup a callback to receive priority set updates.
priority_update_cb_ = priority_set_.addPriorityUpdateCb(
[](uint32_t, const HostVector& hosts_added, const HostVector&) -> absl::Status {
[this](uint32_t, const HostVector& hosts_added, const HostVector&) -> absl::Status {
addClientSideLbPolicyDataToHosts(hosts_added);
updateWeightsOnMainThread();
return absl::OkStatus();
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#pragma once

#include "envoy/extensions/load_balancing_policies/client_side_weighted_round_robin/v3/client_side_weighted_round_robin.pb.h"
#include "envoy/thread_local/thread_local.h"
#include "envoy/thread_local/thread_local_object.h"
#include "envoy/upstream/upstream.h"

#include "source/common/common/callback_impl.h"
#include "source/extensions/load_balancing_policies/common/load_balancer_impl.h"
#include "source/extensions/load_balancing_policies/round_robin/round_robin_lb.h"

Expand All @@ -21,7 +24,8 @@ using OrcaLoadReportProto = xds::data::orca::v3::OrcaLoadReport;
class ClientSideWeightedRoundRobinLbConfig : public Upstream::LoadBalancerConfig {
public:
ClientSideWeightedRoundRobinLbConfig(const ClientSideWeightedRoundRobinLbProto& lb_proto,
Event::Dispatcher& main_thread_dispatcher);
Event::Dispatcher& main_thread_dispatcher,
ThreadLocal::SlotAllocator& tls_slot_allocator);

// Parameters for weight calculation from Orca Load report.
std::vector<std::string> metric_names_for_computing_utilization;
Expand All @@ -32,6 +36,7 @@ class ClientSideWeightedRoundRobinLbConfig : public Upstream::LoadBalancerConfig
std::chrono::milliseconds weight_update_period;

Event::Dispatcher& main_thread_dispatcher_;
ThreadLocal::SlotAllocator& tls_slot_allocator_;
};

/**
Expand Down Expand Up @@ -131,6 +136,12 @@ class ClientSideWeightedRoundRobinLoadBalancer : public Upstream::ThreadAwareLoa
TimeSource& time_source_;
};

// Thread local shim to store callbacks for weight updates of worker local lb.
class ThreadLocalShim : public Envoy::ThreadLocal::ThreadLocalObject {
public:
Common::CallbackManager<uint32_t> apply_weights_cb_helper_;
};

// This class is used to handle the load balancing on the worker thread.
class WorkerLocalLb : public RoundRobinLoadBalancer {
public:
Expand All @@ -139,15 +150,15 @@ class ClientSideWeightedRoundRobinLoadBalancer : public Upstream::ThreadAwareLoa
ClusterLbStats& stats, Runtime::Loader& runtime, Random::RandomGenerator& random,
const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config,
const ClientSideWeightedRoundRobinLbConfig& client_side_weighted_round_robin_config,
TimeSource& time_source);
TimeSource& time_source, OptRef<ThreadLocalShim> tls_shim);

private:
friend class ClientSideWeightedRoundRobinLoadBalancerFriend;

HostConstSharedPtr chooseHost(LoadBalancerContext* context) override;
bool alwaysUseEdfScheduler() const override { return true; };

std::shared_ptr<OrcaLoadReportHandler> orca_load_report_handler_;
Common::CallbackHandlePtr apply_weights_cb_handle_;
};

// Factory used to create worker-local load balancer on the worker thread.
Expand All @@ -158,14 +169,24 @@ class ClientSideWeightedRoundRobinLoadBalancer : public Upstream::ThreadAwareLoa
const Upstream::PrioritySet& priority_set, Runtime::Loader& runtime,
Envoy::Random::RandomGenerator& random, TimeSource& time_source)
: lb_config_(lb_config), cluster_info_(cluster_info), priority_set_(priority_set),
runtime_(runtime), random_(random), time_source_(time_source) {}
runtime_(runtime), random_(random), time_source_(time_source) {
const auto* typed_lb_config =
dynamic_cast<const ClientSideWeightedRoundRobinLbConfig*>(lb_config.ptr());
ASSERT(typed_lb_config != nullptr);
tls_ =
ThreadLocal::TypedSlot<ThreadLocalShim>::makeUnique(typed_lb_config->tls_slot_allocator_);
tls_->set([](Envoy::Event::Dispatcher&) { return std::make_shared<ThreadLocalShim>(); });
}

Upstream::LoadBalancerPtr create(Upstream::LoadBalancerParams params) override;

bool recreateOnHostChange() const override { return false; }

void applyWeightsToAllWorkers(uint32_t priority);

protected:
OptRef<const Upstream::LoadBalancerConfig> lb_config_;
std::unique_ptr<Envoy::ThreadLocal::TypedSlot<ThreadLocalShim>> tls_;

const Upstream::ClusterInfo& cluster_info_;
const Upstream::PrioritySet& priority_set_;
Expand Down Expand Up @@ -200,7 +221,8 @@ class ClientSideWeightedRoundRobinLoadBalancer : public Upstream::ThreadAwareLoa
void updateWeightsOnMainThread();

// Update weights using client side host LB policy data for all `hosts`.
void updateWeightsOnHosts(const HostVector& hosts);
// Returns true if any host weight is updated.
bool updateWeightsOnHosts(const HostVector& hosts);

// Add client side host LB policy data to all `hosts`.
static void addClientSideLbPolicyDataToHosts(const HostVector& hosts);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ namespace ClientSideWeightedRoundRobin {

using ClientSideWeightedRoundRobinLbProto = envoy::extensions::load_balancing_policies::
client_side_weighted_round_robin::v3::ClientSideWeightedRoundRobin;
// using ClusterProto = envoy::config::cluster::v3::Cluster;

class Factory : public Upstream::TypedLoadBalancerFactoryBase<ClientSideWeightedRoundRobinLbProto> {
public:
Expand All @@ -38,7 +37,7 @@ class Factory : public Upstream::TypedLoadBalancerFactoryBase<ClientSideWeighted
const Protobuf::Message& config) override {
const auto& lb_config = dynamic_cast<const ClientSideWeightedRoundRobinLbProto&>(config);
return Upstream::LoadBalancerConfigPtr{new Upstream::ClientSideWeightedRoundRobinLbConfig(
lb_config, context.mainThreadDispatcher())};
lb_config, context.mainThreadDispatcher(), context.threadLocal())};
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,7 @@ void EdfLoadBalancerBase::refresh(uint32_t priority) {
// case EDF creation is skipped. When all original weights are equal and no hosts are in slow
// start mode we can rely on unweighted host pick to do optimal round robin and least-loaded
// host selection with lower memory and CPU overhead.
if (!alwaysUseEdfScheduler() && hostWeightsAreEqual(hosts) && noHostsAreInSlowStart()) {
if (hostWeightsAreEqual(hosts) && noHostsAreInSlowStart()) {
// Skip edf creation.
return;
}
Expand Down Expand Up @@ -963,8 +963,6 @@ void EdfLoadBalancerBase::refresh(uint32_t priority) {
}
}

bool EdfLoadBalancerBase::alwaysUseEdfScheduler() const { return false; }

bool EdfLoadBalancerBase::isSlowStartEnabled() const {
return slow_start_window_ > std::chrono::milliseconds(0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,6 @@ class EdfLoadBalancerBase : public ZoneAwareLoadBalancerBase {

virtual void refresh(uint32_t priority);

// Return `true` if refresh() should always use EDF scheduler, even if host
// weights are all equal. Default to `false`.
virtual bool alwaysUseEdfScheduler() const;

bool isSlowStartEnabled() const;
bool noHostsAreInSlowStart() const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,13 @@ class ClientSideWeightedRoundRobinLoadBalancerTest : public LoadBalancerTestBase
client_side_weighted_round_robin_config_.mutable_metric_names_for_computing_utilization()->Add(
"metric2");

EXPECT_CALL(mock_tls_, allocateSlot());
lb_ = std::make_shared<ClientSideWeightedRoundRobinLoadBalancerFriend>(
std::make_shared<ClientSideWeightedRoundRobinLoadBalancer>(
lb_config_, cluster_info_, priority_set_, runtime_, random_, simTime()),
std::make_shared<ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb>(
priority_set_, local_priority_set_.get(), stats_, runtime_, random_, common_config_,
lb_config_, simTime()));
lb_config_, simTime(), /*tls_shim=*/absl::nullopt));

// Initialize the thread aware load balancer from config.
ASSERT_EQ(lb_->initialize(), absl::OkStatus());
Expand Down Expand Up @@ -161,9 +162,10 @@ class ClientSideWeightedRoundRobinLoadBalancerTest : public LoadBalancerTestBase

NiceMock<MockLoadBalancerContext> lb_context_;
NiceMock<Event::MockDispatcher> dispatcher_;
NiceMock<Envoy::ThreadLocal::MockInstance> mock_tls_;
NiceMock<MockClusterInfo> cluster_info_;
ClientSideWeightedRoundRobinLbConfig lb_config_ =
ClientSideWeightedRoundRobinLbConfig(client_side_weighted_round_robin_config_, dispatcher_);
ClientSideWeightedRoundRobinLbConfig lb_config_ = ClientSideWeightedRoundRobinLbConfig(
client_side_weighted_round_robin_config_, dispatcher_, mock_tls_);
};

//////////////////////////////////////////////////////
Expand Down Expand Up @@ -216,7 +218,7 @@ TEST_P(ClientSideWeightedRoundRobinLoadBalancerTest, UpdateWeightsOneHostHasClie
EXPECT_EQ(hosts[2]->weight(), 42);
}

TEST_P(ClientSideWeightedRoundRobinLoadBalancerTest, UpdateWeightsDefaultIsMedianWeight) {
TEST_P(ClientSideWeightedRoundRobinLoadBalancerTest, UpdateWeightsDefaultIsOddMedianWeight) {
init(false);
HostVector hosts = {
makeTestHost(info_, "tcp://127.0.0.1:80", simTime()),
Expand Down Expand Up @@ -247,6 +249,36 @@ TEST_P(ClientSideWeightedRoundRobinLoadBalancerTest, UpdateWeightsDefaultIsMedia
EXPECT_EQ(hosts[4]->weight(), 42);
}

TEST_P(ClientSideWeightedRoundRobinLoadBalancerTest, UpdateWeightsDefaultIsEvenMedianWeight) {
init(false);
HostVector hosts = {
makeTestHost(info_, "tcp://127.0.0.1:80", simTime()),
makeTestHost(info_, "tcp://127.0.0.1:81", simTime()),
makeTestHost(info_, "tcp://127.0.0.1:82", simTime()),
makeTestHost(info_, "tcp://127.0.0.1:83", simTime()),
makeTestHost(info_, "tcp://127.0.0.1:84", simTime()),
};
simTime().setMonotonicTime(MonotonicTime(std::chrono::seconds(30)));
// Set client side weight for first two hosts.
setHostClientSideWeight(hosts[0], 5, 5, 10);
setHostClientSideWeight(hosts[1], 42, 5, 10);
// Setting client side weights should not change the host weights.
EXPECT_EQ(hosts[0]->weight(), 1);
EXPECT_EQ(hosts[1]->weight(), 1);
EXPECT_EQ(hosts[2]->weight(), 1);
EXPECT_EQ(hosts[3]->weight(), 1);
EXPECT_EQ(hosts[4]->weight(), 1);
// Update weights on hosts.
lb_->updateWeightsOnHosts(hosts);
// First two hosts have client side weight, other hosts get the median
// weight which is average of weights of first two hosts.
EXPECT_EQ(hosts[0]->weight(), 5);
EXPECT_EQ(hosts[1]->weight(), 42);
EXPECT_EQ(hosts[2]->weight(), 23);
EXPECT_EQ(hosts[3]->weight(), 23);
EXPECT_EQ(hosts[4]->weight(), 23);
}

TEST_P(ClientSideWeightedRoundRobinLoadBalancerTest, ChooseHostWithClientSideWeights) {
if (&hostSet() == &failover_host_set_) { // P = 1 does not support zone-aware routing.
return;
Expand Down
Loading

0 comments on commit 150e16d

Please sign in to comment.