From dbb653e049aca74e3cceec367968d5c7a571e3a5 Mon Sep 17 00:00:00 2001 From: Carson Ip Date: Fri, 2 Aug 2024 16:57:01 +0100 Subject: [PATCH] Fix global labels values race condition (#186) Fix a race condition for global labels that use a slice of values instead of a single value. --- aggregators/codec.go | 14 ++++-- aggregators/converter.go | 83 ++++++++++++++++------------------- aggregators/converter_test.go | 31 ++++++++++++- 3 files changed, 77 insertions(+), 51 deletions(-) diff --git a/aggregators/codec.go b/aggregators/codec.go index dcf412d..7972768 100644 --- a/aggregators/codec.go +++ b/aggregators/codec.go @@ -330,7 +330,8 @@ func (gl *globalLabels) ToProto() *aggregationpb.GlobalLabels { } pb.Labels[i].Key = k pb.Labels[i].Value = v.Value - pb.Labels[i].Values = v.Values + pb.Labels[i].Values = slices.Grow(pb.Labels[i].Values, len(v.Values))[:len(v.Values)] + copy(pb.Labels[i].Values, v.Values) i++ } sort.Slice(pb.Labels, func(i, j int) bool { @@ -345,7 +346,8 @@ func (gl *globalLabels) ToProto() *aggregationpb.GlobalLabels { } pb.NumericLabels[i].Key = k pb.NumericLabels[i].Value = v.Value - pb.NumericLabels[i].Values = v.Values + pb.NumericLabels[i].Values = slices.Grow(pb.NumericLabels[i].Values, len(v.Values))[:len(v.Values)] + copy(pb.NumericLabels[i].Values, v.Values) i++ } sort.Slice(pb.NumericLabels, func(i, j int) bool { @@ -359,11 +361,15 @@ func (gl *globalLabels) ToProto() *aggregationpb.GlobalLabels { func (gl *globalLabels) FromProto(pb *aggregationpb.GlobalLabels) { gl.Labels = make(modelpb.Labels, len(pb.Labels)) for _, l := range pb.Labels { - gl.Labels[l.Key] = &modelpb.LabelValue{Value: l.Value, Values: l.Values, Global: true} + gl.Labels[l.Key] = &modelpb.LabelValue{Value: l.Value, Global: true} + gl.Labels[l.Key].Values = slices.Grow(gl.Labels[l.Key].Values, len(l.Values))[:len(l.Values)] + copy(gl.Labels[l.Key].Values, l.Values) } gl.NumericLabels = make(modelpb.NumericLabels, len(pb.NumericLabels)) for _, l := range pb.NumericLabels { - gl.NumericLabels[l.Key] = &modelpb.NumericLabelValue{Value: l.Value, Values: l.Values, Global: true} + gl.NumericLabels[l.Key] = &modelpb.NumericLabelValue{Value: l.Value, Global: true} + gl.NumericLabels[l.Key].Values = slices.Grow(gl.NumericLabels[l.Key].Values, len(l.Values))[:len(l.Values)] + copy(gl.NumericLabels[l.Key].Values, l.Values) } } diff --git a/aggregators/converter.go b/aggregators/converter.go index 2fdb6ce..345e311 100644 --- a/aggregators/converter.go +++ b/aggregators/converter.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "math" + "slices" "sort" "sync" "time" @@ -984,75 +985,67 @@ func overflowSpanMetricsToAPMEvent( } func marshalEventGlobalLabels(e *modelpb.APMEvent) ([]byte, error) { - if len(e.Labels) == 0 && len(e.NumericLabels) == 0 { + var labelsCnt, numericLabelsCnt int + for _, v := range e.Labels { + if !v.Global { + continue + } + labelsCnt++ + } + for _, v := range e.NumericLabels { + if !v.Global { + continue + } + numericLabelsCnt++ + } + + if labelsCnt == 0 && numericLabelsCnt == 0 { return nil, nil } - var pb *aggregationpb.GlobalLabels + pb := aggregationpb.GlobalLabelsFromVTPool() + defer pb.ReturnToVTPool() + pb.Labels = slices.Grow(pb.Labels, labelsCnt)[:labelsCnt] + pb.NumericLabels = slices.Grow(pb.NumericLabels, numericLabelsCnt)[:numericLabelsCnt] + var i int // Keys must be sorted to ensure wire formats are deterministically generated and strings are directly comparable // i.e. Protobuf formats are equal if and only if the structs are equal for k, v := range e.Labels { if !v.Global { continue } - - if pb == nil { - pb = aggregationpb.GlobalLabelsFromVTPool() - defer pb.ReturnToVTPool() - } - - i := len(pb.Labels) - if i == cap(pb.Labels) { - pb.Labels = append(pb.Labels, &aggregationpb.Label{}) - } else { - pb.Labels = pb.Labels[:i+1] - if pb.Labels[i] == nil { - pb.Labels[i] = &aggregationpb.Label{} - } + if pb.Labels[i] == nil { + pb.Labels[i] = &aggregationpb.Label{} } pb.Labels[i].Key = k pb.Labels[i].Value = v.Value - pb.Labels[i].Values = v.Values - } - if pb != nil { - sort.Slice(pb.Labels, func(i, j int) bool { - return pb.Labels[i].Key < pb.Labels[j].Key - }) + pb.Labels[i].Values = slices.Grow(pb.Labels[i].Values, len(v.Values))[:len(v.Values)] + copy(pb.Labels[i].Values, v.Values) + i++ } + sort.Slice(pb.Labels, func(i, j int) bool { + return pb.Labels[i].Key < pb.Labels[j].Key + }) + i = 0 for k, v := range e.NumericLabels { if !v.Global { continue } - - if pb == nil { - pb = aggregationpb.GlobalLabelsFromVTPool() - defer pb.ReturnToVTPool() - } - - i := len(pb.NumericLabels) - if i == cap(pb.NumericLabels) { - pb.NumericLabels = append(pb.NumericLabels, &aggregationpb.NumericLabel{}) - } else { - pb.NumericLabels = pb.NumericLabels[:i+1] - if pb.NumericLabels[i] == nil { - pb.NumericLabels[i] = &aggregationpb.NumericLabel{} - } + if pb.NumericLabels[i] == nil { + pb.NumericLabels[i] = &aggregationpb.NumericLabel{} } pb.NumericLabels[i].Key = k pb.NumericLabels[i].Value = v.Value - pb.NumericLabels[i].Values = v.Values - } - if pb != nil { - sort.Slice(pb.NumericLabels, func(i, j int) bool { - return pb.NumericLabels[i].Key < pb.NumericLabels[j].Key - }) + pb.NumericLabels[i].Values = slices.Grow(pb.NumericLabels[i].Values, len(v.Values))[:len(v.Values)] + copy(pb.NumericLabels[i].Values, v.Values) + i++ } + sort.Slice(pb.NumericLabels, func(i, j int) bool { + return pb.NumericLabels[i].Key < pb.NumericLabels[j].Key + }) - if pb == nil { - return nil, nil - } return pb.MarshalVT() } diff --git a/aggregators/converter_test.go b/aggregators/converter_test.go index d109485..4bc1cf4 100644 --- a/aggregators/converter_test.go +++ b/aggregators/converter_test.go @@ -7,6 +7,7 @@ package aggregators import ( "fmt" "net/netip" + "sync" "testing" "time" @@ -779,8 +780,8 @@ func getTestGlobalLabelsStr(t *testing.T, s string) string { return gls } -func TestMarshalEventGlobalLabels(t *testing.T) { - e := &modelpb.APMEvent{ +func globalLabelsEvent() *modelpb.APMEvent { + return &modelpb.APMEvent{ Labels: modelpb.Labels{ "tag1": &modelpb.LabelValue{ Value: "1", @@ -826,6 +827,10 @@ func TestMarshalEventGlobalLabels(t *testing.T) { }, }, } +} + +func TestMarshalEventGlobalLabels(t *testing.T) { + e := globalLabelsEvent() b, err := marshalEventGlobalLabels(e) require.NoError(t, err) gl := globalLabels{} @@ -856,3 +861,25 @@ func TestMarshalEventGlobalLabels(t *testing.T) { }, }, gl.NumericLabels) } + +func TestMarshalEventGlobalLabelsRace(t *testing.T) { + const N = 1000 + wg := sync.WaitGroup{} + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + e := globalLabelsEvent() + b, err := marshalEventGlobalLabels(e) + require.NoError(t, err) + gl := globalLabels{} + err = gl.UnmarshalBinary(b) + require.NoError(t, err) + b, err = gl.MarshalBinary() + require.NoError(t, err) + err = gl.UnmarshalBinary(b) + require.NoError(t, err) + wg.Done() + }() + } + wg.Wait() +}