Skip to content

Commit

Permalink
remove use of const cast in nnue_feature_transformer.h
Browse files Browse the repository at this point in the history
no functional change
  • Loading branch information
xu-shawn committed Dec 28, 2024
1 parent f656fdf commit dec6dff
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions src/nnue/nnue_feature_transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,39 +256,39 @@ class FeatureTransformer {
#endif
}

void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) const {
static void permute_weights([[maybe_unused]] WeightType* weights,
[[maybe_unused]] BiasType* biases,
[[maybe_unused]] void (*order_fn)(uint64_t*)) {
#if defined(USE_AVX2)
#if defined(USE_AVX512)
constexpr IndexType di = 16;
#else
constexpr IndexType di = 8;
#endif
uint64_t* b = reinterpret_cast<uint64_t*>(const_cast<BiasType*>(&biases[0]));
uint64_t* b = reinterpret_cast<uint64_t*>(&biases[0]);
for (IndexType i = 0; i < HalfDimensions * sizeof(BiasType) / sizeof(uint64_t); i += di)
order_fn(&b[i]);

for (IndexType j = 0; j < InputDimensions; ++j)
{
uint64_t* w =
reinterpret_cast<uint64_t*>(const_cast<WeightType*>(&weights[j * HalfDimensions]));
uint64_t* w = reinterpret_cast<uint64_t*>(&weights[j * HalfDimensions]);
for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(uint64_t);
i += di)
order_fn(&w[i]);
}
#endif
}

inline void scale_weights(bool read) const {
static void scale_weights(WeightType* weights, BiasType* biases, bool read) {
for (IndexType j = 0; j < InputDimensions; ++j)
{
WeightType* w = const_cast<WeightType*>(&weights[j * HalfDimensions]);
WeightType* w = &weights[j * HalfDimensions];
for (IndexType i = 0; i < HalfDimensions; ++i)
w[i] = read ? w[i] * 2 : w[i] / 2;
}

BiasType* b = const_cast<BiasType*>(biases);
for (IndexType i = 0; i < HalfDimensions; ++i)
b[i] = read ? b[i] * 2 : b[i] / 2;
biases[i] = read ? biases[i] * 2 : biases[i] / 2;
}

// Read network parameters
Expand All @@ -298,23 +298,29 @@ class FeatureTransformer {
read_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);

permute_weights(inverse_order_packs);
scale_weights(true);
permute_weights(weights, biases, inverse_order_packs);
scale_weights(weights, biases, true);
return !stream.fail();
}

// Write network parameters
bool write_parameters(std::ostream& stream) const {
BiasType* biasesToWrite = new BiasType[HalfDimensions];
WeightType* weightsToWrite = new WeightType[HalfDimensions * InputDimensions];

permute_weights(order_packs);
scale_weights(false);
std::memcpy(biasesToWrite, biases, sizeof(BiasType) * HalfDimensions);
std::memcpy(weightsToWrite, weights, sizeof(WeightType) * HalfDimensions * InputDimensions);

write_leb_128<BiasType>(stream, biases, HalfDimensions);
write_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
permute_weights(weightsToWrite, biasesToWrite, order_packs);
scale_weights(weightsToWrite, biasesToWrite, false);

write_leb_128<BiasType>(stream, biasesToWrite, HalfDimensions);
write_leb_128<WeightType>(stream, weightsToWrite, HalfDimensions * InputDimensions);
write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);

permute_weights(inverse_order_packs);
scale_weights(true);
delete[] biasesToWrite;
delete[] weightsToWrite;

return !stream.fail();
}

Expand Down

0 comments on commit dec6dff

Please sign in to comment.