
[Nvidia] Support fp8 to bf16 casting on RTX 4000 series #5544

Merged (2 commits) on Jan 7, 2025

Conversation

mbrookhart (Contributor)
I noticed that some of the tests were failing when I was testing on a workstation with a consumer RTX card. It turns out that sm_89 supports fp8 but doesn't support cvt.bf16.f16.

From the ptx spec:

cvt.bf16.{u8/s8/u16/s16/u32/s32/u64/s64/f16/f64/bf16}, cvt.{u8/s8/u16/s16/u32/s32/u64/s64/f16/f64}.bf16, and cvt.tf32.f32.{relu}.{rn/rz} require sm_90 or higher.

This adds a path that first converts to fp32 and then to bf16 when the compute capability is < 90.

This path is already exercised by the tests (specifically several cases in test_core, in particular many variations on dot_scaled).
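The fallback described above replaces the unavailable cvt.bf16.f16 with a widen-then-narrow sequence: f16 → f32 (always supported), then f32 → bf16. As a minimal sketch of what the second step computes numerically (this is *not* the PR's actual C++ lowering code, just a bit-level model assuming the usual round-to-nearest-even behavior of cvt.rn.bf16.f32):

```python
import struct

def f32_to_bf16_bits(x: float) -> int:
    """Narrow an fp32 value to the 16 bf16 bits, rounding the
    discarded low 16 mantissa bits to nearest-even."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # Bias trick for round-to-nearest-even on the truncated half-word.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_f32(b: int) -> float:
    """Widen bf16 bits back to fp32 (exact: bf16 is a prefix of fp32)."""
    (x,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return x

# The sm_89 fallback chain is then: fp8 -> f16 (supported on sm_89)
# -> f32 (plain cvt) -> bf16 (the rounding step modeled above).
```

Because bf16 shares fp32's exponent range and simply truncates the mantissa, the f16 → f32 step is exact and the only rounding happens in the final narrowing, so the two-step path matches what a direct cvt.rn.bf16.f16 would produce.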

@mbrookhart mbrookhart requested a review from ptillet as a code owner January 7, 2025 03:58
@ThomasRaoux ThomasRaoux merged commit 4947a95 into triton-lang:main Jan 7, 2025
7 checks passed