Skip to content

Commit

Permalink
Disable AVX512 support until proper testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ogxd committed Sep 6, 2024
1 parent 90b7fad commit 5b6ffb2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 52 deletions.
2 changes: 1 addition & 1 deletion Equativ.RoaringBitmaps/PopcntAvx2.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ internal static ulong Popcnt(ReadOnlySpan<ulong> longs)
ref Vector256<long> start = ref Unsafe.As<ulong, Vector256<long>>(ref MemoryMarshal.GetReference(longs));
ulong total = Popcnt(ref start, longs.Length / 4);

// Handle remaining bytes
// Handle remaining longs
for (int i = longs.Length - longs.Length % 4; i < longs.Length; i++)
{
total += (ulong)BitOperations.PopCount(longs[i]);
Expand Down
75 changes: 30 additions & 45 deletions Equativ.RoaringBitmaps/PopcntAvx512.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,91 +11,76 @@ namespace Equativ.RoaringBitmaps;
internal static class PopcntAvx512
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static ulong Popcnt(ReadOnlySpan<byte> bytes)
internal static ulong Popcnt(ReadOnlySpan<ulong> longs)
{
ulong total = Popcnt(MemoryMarshal.Cast<byte, Vector512<uint>>(bytes), bytes.Length / 64);
ref Vector512<uint> start = ref Unsafe.As<ulong, Vector512<uint>>(ref MemoryMarshal.GetReference(longs));
ulong total = Popcnt(ref start, longs.Length / 8);

// Handle remaining bytes
for (int i = bytes.Length - bytes.Length % 64; i < bytes.Length; i++)
// Handle remaining longs
for (int i = longs.Length - longs.Length % 4; i < longs.Length; i++)
{
total += (ulong)BitOperations.PopCount(bytes[i]);
total += (ulong)BitOperations.PopCount(longs[i]);
}

return total;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector512<ushort> Popcnt(Vector512<byte> v)
private static Vector512<ushort> PopcntVec(ref Vector512<uint> v)
{
Vector512<byte> m1 = Vector512.Create((byte)0x55);
Vector512<byte> m2 = Vector512.Create((byte)0x33);
Vector512<byte> m4 = Vector512.Create((byte)0x0F);

Vector512<byte> t1 = Avx512BW.Subtract(v, Avx512F.And(Avx512BW.ShiftRightLogical(v.AsUInt16(), 1).AsByte(), m1));
Vector512<byte> t1 = Avx512BW.Subtract(v.AsByte(), Avx512F.And(Avx512BW.ShiftRightLogical(v.AsUInt16(), 1).AsByte(), m1));
Vector512<byte> t2 = Avx512BW.Add(Avx512F.And(t1, m2), Avx512F.And(Avx512BW.ShiftRightLogical(t1.AsUInt16(), 2).AsByte(), m2));
Vector512<byte> t3 = Avx512F.And(Avx512BW.Add(t2, Avx512BW.ShiftRightLogical(t2.AsUInt16(), 4).AsByte()), m4);
return Avx512BW.SumAbsoluteDifferences(t3, Vector512<byte>.Zero);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void CSA(out Vector512<uint> h, out Vector512<uint> l, Vector512<uint> a, Vector512<uint> b, Vector512<uint> c)
private static void CSA(out Vector512<uint> h, out Vector512<uint> l, ref Vector512<uint> a, ref Vector512<uint> b, ref Vector512<uint> c)
{
l = Avx512F.TernaryLogic(c, b, a, 0x96);
h = Avx512F.TernaryLogic(c, b, a, 0xe8);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ulong Popcnt(ReadOnlySpan<Vector512<uint>> data, int size)
private static ulong Popcnt(ref Vector512<uint> start, int size)
{
Vector512<ulong> total = Vector512<ulong>.Zero;
Vector512<uint> ones = Vector512<uint>.Zero;
Vector512<uint> twos = Vector512<uint>.Zero;
Vector512<uint> fours = Vector512<uint>.Zero;
Vector512<uint> eights = Vector512<uint>.Zero;
Vector512<uint> sixteens = Vector512<uint>.Zero;
Vector512<uint> twosA, twosB, foursA, foursB, eightsA, eightsB;

int limit = size - size % 16;
int i = 0;
int limit = size - size % 4;

for (; i < limit; i += 16)
{
CSA(out twosA, out ones, ones, data[i + 0], data[i + 1]);
CSA(out twosB, out ones, ones, data[i + 2], data[i + 3]);
CSA(out foursA, out twos, twos, twosA, twosB);
CSA(out twosA, out ones, ones, data[i + 4], data[i + 5]);
CSA(out twosB, out ones, ones, data[i + 6], data[i + 7]);
CSA(out foursB, out twos, twos, twosA, twosB);
CSA(out eightsA, out fours, fours, foursA, foursB);
CSA(out twosA, out ones, ones, data[i + 8], data[i + 9]);
CSA(out twosB, out ones, ones, data[i + 10], data[i + 11]);
CSA(out foursA, out twos, twos, twosA, twosB);
CSA(out twosA, out ones, ones, data[i + 12], data[i + 13]);
CSA(out twosB, out ones, ones, data[i + 14], data[i + 15]);
CSA(out foursB, out twos, twos, twosA, twosB);
CSA(out eightsB, out fours, fours, foursA, foursB);
CSA(out sixteens, out eights, eights, eightsA, eightsB);
ref var end = ref Unsafe.Add(ref start, limit);

total = Avx512F.Add(total, Popcnt(sixteens.AsByte()).AsUInt64());
while (Unsafe.IsAddressLessThan(ref start, ref end))
{
CSA(out var twosA, out ones, ref ones, ref start, ref Unsafe.Add(ref start, 1));
CSA(out var twosB, out ones, ref ones, ref Unsafe.Add(ref start, 2), ref Unsafe.Add(ref start, 3));
CSA(out var fours, out twos, ref twos, ref twosA, ref twosB);

total = Avx512F.Add(total, PopcntVec(ref fours).AsUInt64());

start = ref Unsafe.Add(ref start, 4);
}

total = Avx512F.ShiftLeftLogical(total, 4);
total = Avx512F.Add(total, Avx512F.ShiftLeftLogical(Popcnt(eights.AsByte()).AsUInt64(), 3));
total = Avx512F.Add(total, Avx512F.ShiftLeftLogical(Popcnt(fours.AsByte()).AsUInt64(), 2));
total = Avx512F.Add(total, Avx512F.ShiftLeftLogical(Popcnt(twos.AsByte()).AsUInt64(), 1));
total = Avx512F.Add(total, Popcnt(ones.AsByte()).AsUInt64());
total = Avx512F.Add(total, Avx512F.ShiftLeftLogical(PopcntVec(ref twos).AsUInt64(), 1));
total = Avx512F.Add(total, PopcntVec(ref ones).AsUInt64());

ref var end2 = ref Unsafe.Add(ref start, size % 4);

for (; i < size; i++)
// Handle remaining vectors
while (Unsafe.IsAddressLessThan(ref start, ref end2))
{
total = Avx512F.Add(total, Popcnt(data[i].AsByte()).AsUInt64());
total = Avx512F.Add(total, PopcntVec(ref start).AsUInt64());
start = ref Unsafe.Add(ref start, 1);
}

return SimdSumEpu64(total);
}

private static ulong SimdSumEpu64(Vector512<ulong> v)
{
Vector256<ulong> sum256 = Avx2.Add(v.GetLower(), v.GetUpper());
Vector256<ulong> sum256 = Avx2.Add(total.GetLower(), total.GetUpper());
Vector128<ulong> sum128 = Sse2.Add(sum256.GetLower(), sum256.GetUpper());
return sum128.GetElement(0) + sum128.GetElement(1);
}
Expand Down
11 changes: 5 additions & 6 deletions Equativ.RoaringBitmaps/Utils.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;

Expand All @@ -19,14 +17,15 @@ internal static int Popcnt(ulong[] longs)
{
return (int)PopcntNeon.Popcnt(longs.AsSpan());
}
if (Avx512BW.IsSupported)
{
return (int)PopcntAvx512.Popcnt(MemoryMarshal.Cast<ulong, byte>(longs));
}
if (Avx2.IsSupported)
{
return (int)PopcntAvx2.Popcnt(longs.AsSpan());
}
// AVX512 Support needs proper testing before being enabled
// if (Avx512BW.IsSupported)
// {
// return (int)PopcntAvx512.Popcnt(longs.AsSpan());
// }

return Popcnt64.Popcnt(longs);
}
Expand Down

0 comments on commit 5b6ffb2

Please sign in to comment.