From 106a59ac8dbaa91016852681fb6cf20548cd5c11 Mon Sep 17 00:00:00 2001 From: Oleg Stepanischev Date: Sat, 6 Apr 2024 22:14:47 +0300 Subject: [PATCH] SSE2/ARM code updated --- src/ZstdSharp/Unsafe/ZstdLazy.cs | 113 ++++++++++++++++++++++++------- 1 file changed, 90 insertions(+), 23 deletions(-) diff --git a/src/ZstdSharp/Unsafe/ZstdLazy.cs b/src/ZstdSharp/Unsafe/ZstdLazy.cs index 98010b4..d163f3d 100644 --- a/src/ZstdSharp/Unsafe/ZstdLazy.cs +++ b/src/ZstdSharp/Unsafe/ZstdLazy.cs @@ -940,6 +940,22 @@ private static uint ZSTD_row_matchMaskGroupWidth(uint rowEntries) { assert(rowEntries == 16 || rowEntries == 32 || rowEntries == 64); assert(rowEntries <= 64); +#if NET5_0_OR_GREATER + if (AdvSimd.IsSupported && BitConverter.IsLittleEndian) + { + if (rowEntries == 16) + return 4; +#if NET9_0_OR_GREATER + if (AdvSimd.Arm64.IsSupported) + { + if (rowEntries == 32) + return 2; + if (rowEntries == 64) + return 1; + } +#endif + } +#endif return 1; } @@ -947,24 +963,42 @@ private static uint ZSTD_row_matchMaskGroupWidth(uint rowEntries) [MethodImpl(MethodImplOptions.AggressiveInlining)] private static ulong ZSTD_row_getSSEMask(int nbChunks, byte* src, byte tag, uint head) { - Vector128 comparisonMask = Vector128.Create((sbyte)tag); - int* matches = stackalloc int[4]; - memset(matches, 0, sizeof(int) * 4); - int i; - assert(nbChunks == 1 || nbChunks == 2 || nbChunks == 4); - for (i = 0; i < nbChunks; i++) + Vector128 comparisonMask = Vector128.Create(tag); + assert(nbChunks is 1 or 2 or 4); + if (nbChunks == 1) { - Vector128 chunk = Sse2.LoadVector128((sbyte*)(Vector128*)(void*)(src + 16 * i)); - Vector128 equalMask = Sse2.CompareEqual(chunk, comparisonMask); - matches[i] = Sse2.MoveMask(equalMask); + Vector128 chunk0 = Sse2.LoadVector128(src); + Vector128 equalMask0 = Sse2.CompareEqual(chunk0, comparisonMask); + int matches0 = Sse2.MoveMask(equalMask0); + return BitOperations.RotateRight((ushort)matches0, (int)head); } - if (nbChunks == 1) - return BitOperations.RotateRight((ushort)matches[0], (int)head); if (nbChunks == 2) - return BitOperations.RotateRight((uint)matches[1] << 16 | (uint)matches[0], (int)head); - assert(nbChunks == 4); - return BitOperations.RotateRight((ulong)matches[3] << 48 | (ulong)matches[2] << 32 | (ulong)matches[1] << 16 | (ulong)matches[0], (int)head); + { + Vector128 chunk0 = Sse2.LoadVector128(src); + Vector128 equalMask0 = Sse2.CompareEqual(chunk0, comparisonMask); + int matches0 = Sse2.MoveMask(equalMask0); + Vector128 chunk1 = Sse2.LoadVector128(src + 16); + Vector128 equalMask1 = Sse2.CompareEqual(chunk1, comparisonMask); + int matches1 = Sse2.MoveMask(equalMask1); + return BitOperations.RotateRight((uint)matches1 << 16 | (uint)matches0, (int)head); + } + + { + Vector128 chunk0 = Sse2.LoadVector128(src); + Vector128 equalMask0 = Sse2.CompareEqual(chunk0, comparisonMask); + int matches0 = Sse2.MoveMask(equalMask0); + Vector128 chunk1 = Sse2.LoadVector128(src + 16 * 1); + Vector128 equalMask1 = Sse2.CompareEqual(chunk1, comparisonMask); + int matches1 = Sse2.MoveMask(equalMask1); + Vector128 chunk2 = Sse2.LoadVector128(src + 16 * 2); + Vector128 equalMask2 = Sse2.CompareEqual(chunk2, comparisonMask); + int matches2 = Sse2.MoveMask(equalMask2); + Vector128 chunk3 = Sse2.LoadVector128(src + 16 * 3); + Vector128 equalMask3 = Sse2.CompareEqual(chunk3, comparisonMask); + int matches3 = Sse2.MoveMask(equalMask3); + return BitOperations.RotateRight((ulong)matches3 << 48 | (ulong)matches2 << 32 | (ulong)matches1 << 16 | (uint)matches0, (int)head); + } } #endif @@ -993,23 +1027,56 @@ private static ulong ZSTD_row_getMatchMask(byte* tagRow, byte tag, uint headGrou { if (rowEntries == 16) { + /* vshrn_n_u16 shifts by 4 every u16 and narrows to 8 lower bits. + * After that groups of 4 bits represent the equalMask. We lower + * all bits except the highest in these groups by doing AND with + * 0x88 = 0b10001000. + */ Vector128 chunk = AdvSimd.LoadVector128(src); Vector128 equalMask = AdvSimd.CompareEqual(chunk, AdvSimd.DuplicateToVector128(tag)).As(); - Vector128 t0 = AdvSimd.ShiftLeftLogical(equalMask, 7); - Vector128 t1 = AdvSimd.ShiftRightAndInsert(t0, t0, 14).As(); - Vector128 t2 = AdvSimd.ShiftRightLogical(t1, 14).As(); - Vector128 t3 = AdvSimd.ShiftRightLogicalAdd(t2, t2, 28).As(); - ushort hi = AdvSimd.Extract(t3, 8); - ushort lo = AdvSimd.Extract(t3, 0); - return BitOperations.RotateRight((ushort)(hi << 8 | lo), (int)headGrouped); + Vector64 res = AdvSimd.ShiftRightLogicalNarrowingLower(equalMask, 4); + ulong matches = res.As().GetElement(0); + return BitOperations.RotateRight(matches, (int)headGrouped) & 0x8888888888888888; } else if (rowEntries == 32) { - // todo, there is no vld2q_u16 in c# +#if NET9_0_OR_GREATER + if (AdvSimd.Arm64.IsSupported) + { + /* Same idea as with rowEntries == 16 but doing AND with + * 0x55 = 0b01010101. + */ + (Vector128 chunk0, Vector128 chunk1) = AdvSimd.Arm64.LoadVector128x2AndUnzip((ushort*)src); + Vector128 dup = AdvSimd.DuplicateToVector128(tag); + Vector64 t0 = AdvSimd.ShiftRightLogicalNarrowingLower(AdvSimd.CompareEqual(chunk0.As(), dup).As(), 6); + Vector64 t1 = AdvSimd.ShiftRightLogicalNarrowingLower(AdvSimd.CompareEqual(chunk1.As(), dup).As(), 6); + Vector64 res = AdvSimd.ShiftLeftAndInsert(t0, t1, 4); + ulong matches = res.As().GetElement(0); + return BitOperations.RotateRight(matches, (int)headGrouped) & 0x5555555555555555; + } +#endif } else { /* rowEntries == 64 */ - // todo, there is no vld4q_u8 in c# +#if NET9_0_OR_GREATER + if (AdvSimd.Arm64.IsSupported) + { + (Vector128 chunk0, Vector128 chunk1, Vector128 chunk2, Vector128 chunk3) = AdvSimd.Arm64.LoadVector128x4AndUnzip(src); + Vector128 dup = AdvSimd.DuplicateToVector128(tag); + Vector128 cmp0 = AdvSimd.CompareEqual(chunk0, dup); + Vector128 cmp1 = AdvSimd.CompareEqual(chunk1, dup); + Vector128 cmp2 = AdvSimd.CompareEqual(chunk2, dup); + Vector128 cmp3 = AdvSimd.CompareEqual(chunk3, dup); + + Vector128 t0 = AdvSimd.ShiftRightAndInsert(cmp1, cmp0, 1); + Vector128 t1 = AdvSimd.ShiftRightAndInsert(cmp3, cmp2, 1); + Vector128 t2 = AdvSimd.ShiftRightAndInsert(t1, t0, 2); + Vector128 t3 = AdvSimd.ShiftRightAndInsert(t2, t2, 4); + Vector64 t4 = AdvSimd.ShiftRightLogicalNarrowingLower(t3.As(), 4); + ulong matches = t4.As().GetElement(0); + return BitOperations.RotateRight(matches, (int) headGrouped); + } +#endif } } #endif