Commit

SSE2/ARM code updated
oleg-st committed Apr 6, 2024
1 parent: c1171b8 · commit: 106a59a
Showing 1 changed file with 90 additions and 23 deletions.
src/ZstdSharp/Unsafe/ZstdLazy.cs — 113 changes: 90 additions & 23 deletions
@@ -940,31 +940,65 @@ private static uint ZSTD_row_matchMaskGroupWidth(uint rowEntries)
 {
     assert(rowEntries == 16 || rowEntries == 32 || rowEntries == 64);
     assert(rowEntries <= 64);
 #if NET5_0_OR_GREATER
     if (AdvSimd.IsSupported && BitConverter.IsLittleEndian)
     {
         if (rowEntries == 16)
             return 4;
+#if NET9_0_OR_GREATER
+        if (AdvSimd.Arm64.IsSupported)
+        {
+            if (rowEntries == 32)
+                return 2;
+            if (rowEntries == 64)
+                return 1;
+        }
+#endif
     }
 #endif
     return 1;
 }
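(Editorial note, not part of the commit: the value returned here is the number of bits each row entry occupies in the mask produced by ZSTD_row_getMatchMask — 1 bit per entry on the SSE2 path, 4/2/1 bits on the NEON paths above. A rough, hypothetical sketch of how such a grouped mask might be consumed; the names below are illustrative, not the library's API.)

using System.Numerics;

// Hypothetical consumer of a grouped match mask. headGrouped is assumed to be
// the row head index already multiplied by groupWidth; rowMask is rowEntries - 1.
static void VisitMatches(ulong matchMask, uint headGrouped, uint groupWidth, uint rowMask)
{
    for (; matchMask != 0; matchMask &= matchMask - 1)               // clear lowest set bit each pass
    {
        uint bit = (uint)BitOperations.TrailingZeroCount(matchMask); // position of next match bit
        uint entry = ((headGrouped + bit) / groupWidth) & rowMask;   // back to a row-entry index
        // ... probe the candidate stored at 'entry' ...
    }
}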

 #if NETCOREAPP3_0_OR_GREATER
 [MethodImpl(MethodImplOptions.AggressiveInlining)]
 private static ulong ZSTD_row_getSSEMask(int nbChunks, byte* src, byte tag, uint head)
 {
-    Vector128<sbyte> comparisonMask = Vector128.Create((sbyte)tag);
-    int* matches = stackalloc int[4];
-    memset(matches, 0, sizeof(int) * 4);
-    int i;
-    assert(nbChunks == 1 || nbChunks == 2 || nbChunks == 4);
-    for (i = 0; i < nbChunks; i++)
+    Vector128<byte> comparisonMask = Vector128.Create(tag);
+    assert(nbChunks is 1 or 2 or 4);
+    if (nbChunks == 1)
     {
-        Vector128<sbyte> chunk = Sse2.LoadVector128((sbyte*)(Vector128<sbyte>*)(void*)(src + 16 * i));
-        Vector128<sbyte> equalMask = Sse2.CompareEqual(chunk, comparisonMask);
-        matches[i] = Sse2.MoveMask(equalMask);
+        Vector128<byte> chunk0 = Sse2.LoadVector128(src);
+        Vector128<byte> equalMask0 = Sse2.CompareEqual(chunk0, comparisonMask);
+        int matches0 = Sse2.MoveMask(equalMask0);
+        return BitOperations.RotateRight((ushort)matches0, (int)head);
     }

-    if (nbChunks == 1)
-        return BitOperations.RotateRight((ushort)matches[0], (int)head);
     if (nbChunks == 2)
-        return BitOperations.RotateRight((uint)matches[1] << 16 | (uint)matches[0], (int)head);
-    assert(nbChunks == 4);
-    return BitOperations.RotateRight((ulong)matches[3] << 48 | (ulong)matches[2] << 32 | (ulong)matches[1] << 16 | (ulong)matches[0], (int)head);
+    {
+        Vector128<byte> chunk0 = Sse2.LoadVector128(src);
+        Vector128<byte> equalMask0 = Sse2.CompareEqual(chunk0, comparisonMask);
+        int matches0 = Sse2.MoveMask(equalMask0);
+        Vector128<byte> chunk1 = Sse2.LoadVector128(src + 16);
+        Vector128<byte> equalMask1 = Sse2.CompareEqual(chunk1, comparisonMask);
+        int matches1 = Sse2.MoveMask(equalMask1);
+        return BitOperations.RotateRight((uint)matches1 << 16 | (uint)matches0, (int)head);
+    }
+
+    {
+        Vector128<byte> chunk0 = Sse2.LoadVector128(src);
+        Vector128<byte> equalMask0 = Sse2.CompareEqual(chunk0, comparisonMask);
+        int matches0 = Sse2.MoveMask(equalMask0);
+        Vector128<byte> chunk1 = Sse2.LoadVector128(src + 16 * 1);
+        Vector128<byte> equalMask1 = Sse2.CompareEqual(chunk1, comparisonMask);
+        int matches1 = Sse2.MoveMask(equalMask1);
+        Vector128<byte> chunk2 = Sse2.LoadVector128(src + 16 * 2);
+        Vector128<byte> equalMask2 = Sse2.CompareEqual(chunk2, comparisonMask);
+        int matches2 = Sse2.MoveMask(equalMask2);
+        Vector128<byte> chunk3 = Sse2.LoadVector128(src + 16 * 3);
+        Vector128<byte> equalMask3 = Sse2.CompareEqual(chunk3, comparisonMask);
+        int matches3 = Sse2.MoveMask(equalMask3);
+        return BitOperations.RotateRight((ulong)matches3 << 48 | (ulong)matches2 << 32 | (ulong)matches1 << 16 | (uint)matches0, (int)head);
+    }
 }
 #endif
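(Editorial illustration of the SSE2 building block used above, not code from this commit: CompareEqual produces 0xFF in every lane that equals the tag, and MoveMask packs those lanes into one bit each, so a 16-byte chunk yields a 16-bit match mask. ZSTD_row_getSSEMask concatenates one such mask per chunk and rotates the result right by head.)

using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

// Requires SSE2. Bit i of the returned mask is set exactly when data[i] == tag.
static unsafe int MatchBits16(byte* data, byte tag)
{
    Vector128<byte> chunk = Sse2.LoadVector128(data);                     // 16 candidate tag bytes
    Vector128<byte> eq = Sse2.CompareEqual(chunk, Vector128.Create(tag)); // 0xFF where equal, 0x00 otherwise
    return Sse2.MoveMask(eq);                                             // one bit per byte
}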

@@ -993,23 +1027,56 @@ private static ulong ZSTD_row_getMatchMask(byte* tagRow, byte tag, uint headGrouped
 {
     if (rowEntries == 16)
     {
+        /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits.
+         * After that groups of 4 bits represent the equalMask. We lower
+         * all bits except the highest in these groups by doing AND with
+         * 0x88 = 0b10001000.
+         */
         Vector128<byte> chunk = AdvSimd.LoadVector128(src);
         Vector128<ushort> equalMask = AdvSimd.CompareEqual(chunk, AdvSimd.DuplicateToVector128(tag)).As<byte, ushort>();
-        Vector128<ushort> t0 = AdvSimd.ShiftLeftLogical(equalMask, 7);
-        Vector128<uint> t1 = AdvSimd.ShiftRightAndInsert(t0, t0, 14).As<ushort, uint>();
-        Vector128<ulong> t2 = AdvSimd.ShiftRightLogical(t1, 14).As<uint, ulong>();
-        Vector128<byte> t3 = AdvSimd.ShiftRightLogicalAdd(t2, t2, 28).As<ulong, byte>();
-        ushort hi = AdvSimd.Extract(t3, 8);
-        ushort lo = AdvSimd.Extract(t3, 0);
-        return BitOperations.RotateRight((ushort)(hi << 8 | lo), (int)headGrouped);
+        Vector64<byte> res = AdvSimd.ShiftRightLogicalNarrowingLower(equalMask, 4);
+        ulong matches = res.As<byte, ulong>().GetElement(0);
+        return BitOperations.RotateRight(matches, (int)headGrouped) & 0x8888888888888888;
     }
     else if (rowEntries == 32)
     {
         // todo, there is no vld2q_u16 in c#
+#if NET9_0_OR_GREATER
+        if (AdvSimd.Arm64.IsSupported)
+        {
+            /* Same idea as with rowEntries == 16 but doing AND with
+             * 0x55 = 0b01010101.
+             */
+            (Vector128<ushort> chunk0, Vector128<ushort> chunk1) = AdvSimd.Arm64.LoadVector128x2AndUnzip((ushort*)src);
+            Vector128<byte> dup = AdvSimd.DuplicateToVector128(tag);
+            Vector64<byte> t0 = AdvSimd.ShiftRightLogicalNarrowingLower(AdvSimd.CompareEqual(chunk0.As<ushort, byte>(), dup).As<byte, ushort>(), 6);
+            Vector64<byte> t1 = AdvSimd.ShiftRightLogicalNarrowingLower(AdvSimd.CompareEqual(chunk1.As<ushort, byte>(), dup).As<byte, ushort>(), 6);
+            Vector64<byte> res = AdvSimd.ShiftLeftAndInsert(t0, t1, 4);
+            ulong matches = res.As<byte, ulong>().GetElement(0);
+            return BitOperations.RotateRight(matches, (int)headGrouped) & 0x5555555555555555;
+        }
+#endif
     }
     else
     { /* rowEntries == 64 */
         // todo, there is no vld4q_u8 in c#
+#if NET9_0_OR_GREATER
+        if (AdvSimd.Arm64.IsSupported)
+        {
+            (Vector128<byte> chunk0, Vector128<byte> chunk1, Vector128<byte> chunk2, Vector128<byte> chunk3) = AdvSimd.Arm64.LoadVector128x4AndUnzip(src);
+            Vector128<byte> dup = AdvSimd.DuplicateToVector128(tag);
+            Vector128<byte> cmp0 = AdvSimd.CompareEqual(chunk0, dup);
+            Vector128<byte> cmp1 = AdvSimd.CompareEqual(chunk1, dup);
+            Vector128<byte> cmp2 = AdvSimd.CompareEqual(chunk2, dup);
+            Vector128<byte> cmp3 = AdvSimd.CompareEqual(chunk3, dup);
+
+            Vector128<byte> t0 = AdvSimd.ShiftRightAndInsert(cmp1, cmp0, 1);
+            Vector128<byte> t1 = AdvSimd.ShiftRightAndInsert(cmp3, cmp2, 1);
+            Vector128<byte> t2 = AdvSimd.ShiftRightAndInsert(t1, t0, 2);
+            Vector128<byte> t3 = AdvSimd.ShiftRightAndInsert(t2, t2, 4);
+            Vector64<byte> t4 = AdvSimd.ShiftRightLogicalNarrowingLower(t3.As<byte, ushort>(), 4);
+            ulong matches = t4.As<byte, ulong>().GetElement(0);
+            return BitOperations.RotateRight(matches, (int)headGrouped);
+        }
+#endif
     }
 }
 #endif
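(To make the vshrn comment above concrete, here is an editorial scalar model of the rowEntries == 16 path, not code from the commit: each comparison byte is 0xFF or 0x00; reading byte pairs as little-endian u16 lanes, shifting right by 4 and narrowing to a byte keeps the high nibble of the even byte and the low nibble of the odd byte, so the 16 results pack into a 64-bit value with one 4-bit group per entry.)

// Scalar model of vshrn_n_u16(equalMask, 4) for 16 entries (assumes little-endian).
// cmp[i] is 0xFF when row entry i matched the tag, 0x00 otherwise.
static ulong NarrowedMask16(byte[] cmp)
{
    ulong matches = 0;
    for (int j = 0; j < 8; j++)
    {
        uint lane = (uint)(cmp[2 * j] | (cmp[2 * j + 1] << 8)); // one u16 lane from two bytes
        byte narrowed = (byte)(lane >> 4);                      // shift by 4, keep low 8 bits
        matches |= (ulong)narrowed << (8 * j);                  // 4 bits per entry, 64 bits total
    }
    // ZSTD_row_getMatchMask then rotates this by headGrouped and ANDs with
    // 0x8888888888888888 so only the top bit of each 4-bit group survives.
    return matches;
}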
