Skip to content

Commit

Permalink
wgsl: Optimize memory allocations when quantizing data (#3052)
Browse files Browse the repository at this point in the history
This improves my local gen_cache run time from 1148.2532s to 240.2233s

Fixes #3051
  • Loading branch information
zoddicus authored Oct 5, 2023
1 parent 4619a2b commit b0ea37d
Showing 1 changed file with 21 additions and 23 deletions.
44 changes: 21 additions & 23 deletions src/webgpu/util/math.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@ import {
} from '../../external/petamoriken/float16/float16.js';

import { kBit, kValue } from './constants.js';
import {
f32,
f16,
floatBitsToNumber,
i32,
kFloat16Format,
kFloat32Format,
u32,
} from './conversion.js';
import { floatBitsToNumber, i32, kFloat16Format, kFloat32Format, u32 } from './conversion.js';

/**
* A multiple of 8 guaranteed to be way too large to allocate (just under 8 pebibytes).
Expand Down Expand Up @@ -529,21 +521,20 @@ export function correctlyRoundedF32(n: number): number[] {

// f32 finite
if (n <= kValue.f32.positive.max && n >= kValue.f32.negative.min) {
const n_32 = new Float32Array([n])[0];
const converted: number = n_32;
if (n === converted) {
const n_32 = quantizeToF32(n);
if (n === n_32) {
// n is precisely expressible as a f32, so should not be rounded
return [n];
}

if (converted > n) {
if (n_32 > n) {
// n_32 rounded towards +inf, so is after n
const other = nextAfterF32(n_32, 'negative', 'no-flush');
return [other, converted];
return [other, n_32];
} else {
// n_32 rounded towards -inf, so is before n
const other = nextAfterF32(n_32, 'positive', 'no-flush');
return [converted, other];
return [n_32, other];
}
}

Expand Down Expand Up @@ -598,21 +589,20 @@ export function correctlyRoundedF16(n: number): number[] {

// f16 finite
if (n <= kValue.f16.positive.max && n >= kValue.f16.negative.min) {
const n_16 = new Float16Array([n])[0];
const converted: number = n_16;
if (n === converted) {
const n_16 = quantizeToF16(n);
if (n === n_16) {
// n is precisely expressible as a f16, so should not be rounded
return [n];
}

if (converted > n) {
if (n_16 > n) {
// n_16 rounded towards +inf, so is after n
const other = nextAfterF16(n_16, 'negative', 'no-flush');
return [other, converted];
return [other, n_16];
} else {
// n_16 rounded towards -inf, so is before n
const other = nextAfterF16(n_16, 'positive', 'no-flush');
return [converted, other];
return [n_16, other];
}
}

Expand Down Expand Up @@ -2004,14 +1994,22 @@ export interface QuantizeFunc {
(num: number): number;
}

/**
 * Scratch storage for quantizeToF32: a single-element Float32Array backed by a 4-byte
 * ArrayBuffer. It is allocated once at module scope so quantizing does not need to create
 * a new typed array on every call.
 */
const quantizeToF32Data = new Float32Array(new ArrayBuffer(4));

/**
 * @param num the value to quantize
 * @returns the closest 32-bit floating point value to the input
 */
export function quantizeToF32(num: number): number {
// NOTE(review): this rendering of the diff has lost its +/- markers. The next line appears to
// be the removed pre-change implementation (this same commit drops f32 from the
// './conversion.js' import); the two statements after it are the added replacement.
return f32(num).value as number;
// Write the input into the shared Float32Array slot, then read the stored element back.
quantizeToF32Data[0] = num;
return quantizeToF32Data[0];
}

/**
 * Scratch storage for quantizeToF16: a single-element Float16Array backed by a 2-byte
 * ArrayBuffer. It is allocated once at module scope so quantizing does not need to create
 * a new typed array on every call.
 * NOTE(review): Float16Array is presumably the one imported from the float16 package at the
 * top of the file — confirm against the full import list.
 */
const quantizeToF16Data = new Float16Array(new ArrayBuffer(2));

/**
 * @param num the value to quantize
 * @returns the closest 16-bit floating point value to the input
 */
export function quantizeToF16(num: number): number {
// NOTE(review): this rendering of the diff has lost its +/- markers. The next line appears to
// be the removed pre-change implementation (this same commit drops f16 from the
// './conversion.js' import); the two statements after it are the added replacement.
return f16(num).value as number;
// Write the input into the shared Float16Array slot, then read the stored element back.
quantizeToF16Data[0] = num;
return quantizeToF16Data[0];
}

/** @returns the closest 32-bit signed integer value to the input */
Expand Down

0 comments on commit b0ea37d

Please sign in to comment.