fix YaoBlocks.mat for rotation gates with TracedRNumber parameters (#306)

* fix `YaoBlocks.mat` for rotation gates with `TracedRNumber` parameters

* Update ext/ReactantYaoBlocksExt.jl
mofeing authored Nov 27, 2024
1 parent 856da5d commit 4981557
Showing 2 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -23,12 +23,14 @@ AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"

[compat]
AbstractFFTs = "1.5"
37 changes: 37 additions & 0 deletions ext/ReactantYaoBlocksExt.jl
@@ -0,0 +1,37 @@
module ReactantYaoBlocksExt

using Reactant
using YaoBlocks

function YaoBlocks.mat(
    ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:XGate}
) where {D,T,S}
    M = Reactant.broadcast_to_size(zero(T), (2, 2))
    M[1, 1] = cos(R.theta / 2)
    M[2, 2] = cos(R.theta / 2)
    M[1, 2] = -im * sin(R.theta / 2)
    M[2, 1] = -im * sin(R.theta / 2)
    return M
end

function YaoBlocks.mat(
    ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:YGate}
) where {D,T,S}
    M = Reactant.broadcast_to_size(zero(T), (2, 2))
    M[1, 1] = cos(R.theta / 2)
    M[2, 2] = cos(R.theta / 2)
    M[1, 2] = -sin(R.theta / 2)
    M[2, 1] = sin(R.theta / 2)
    return M
end

function YaoBlocks.mat(
    ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:ZGate}
) where {D,T,S}
    M = Reactant.broadcast_to_size(zero(T), (2, 2))
    M[1, 1] = exp(-im * R.theta / 2)
    M[2, 2] = exp(im * R.theta / 2)
    return M
end

end
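
A minimal sketch, in plain Julia without Reactant or YaoBlocks, for checking the entries that the three methods above write into `M`. It assumes the usual rotation-gate convention exp(-iθG/2) for a Pauli matrix G. The helper names `rx`, `ry`, `rz` and the `PAULI_*` constants are hypothetical and not part of the commit, and ordinary `zeros` arrays stand in for `Reactant.broadcast_to_size`.

```julia
using LinearAlgebra  # provides the matrix exponential exp(::Matrix)

# Pauli matrices (hypothetical names, for illustration only).
const PAULI_X = ComplexF64[0 1; 1 0]
const PAULI_Y = ComplexF64[0 -im; im 0]
const PAULI_Z = ComplexF64[1 0; 0 -1]

# Same entry-by-entry construction as the XGate method, on a plain array.
function rx(θ)
    M = zeros(ComplexF64, 2, 2)
    M[1, 1] = cos(θ / 2)
    M[2, 2] = cos(θ / 2)
    M[1, 2] = -im * sin(θ / 2)
    M[2, 1] = -im * sin(θ / 2)
    return M
end

# Same construction as the YGate method.
function ry(θ)
    M = zeros(ComplexF64, 2, 2)
    M[1, 1] = cos(θ / 2)
    M[2, 2] = cos(θ / 2)
    M[1, 2] = -sin(θ / 2)
    M[2, 1] = sin(θ / 2)
    return M
end

# Same construction as the ZGate method (off-diagonal entries stay zero).
function rz(θ)
    M = zeros(ComplexF64, 2, 2)
    M[1, 1] = exp(-im * θ / 2)
    M[2, 2] = exp(im * θ / 2)
    return M
end

θ = 0.7  # arbitrary test angle
for (name, closed_form, pauli) in
    (("Rx", rx, PAULI_X), ("Ry", ry, PAULI_Y), ("Rz", rz, PAULI_Z))
    reference = exp(-im * θ / 2 * pauli)  # matrix exponential
    println(name, " matches exp(-iθG/2): ", closed_form(θ) ≈ reference)
end
```

The comparison only confirms the closed-form entries; it says nothing about how the methods behave under Reactant tracing.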

1 comment on commit 4981557

@github-actions

Reactant.jl Benchmarks

| Benchmark suite | Current: 4981557 | Previous: 856da5d | Ratio |
|---|---|---|---|
| ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1257741783 ns | 1262529790 ns | 1.00 |
| ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant | 1443950901 ns | 1295960067 ns | 1.11 |
| ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1385605290 ns | 1388414450 ns | 1.00 |
| ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) | 2794781252 ns | 2845010519 ns | 0.98 |
| ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux | 203499640 ns | 216191284 ns | 0.94 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 5710482855 ns | 6553949825 ns | 0.87 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant | 5341487680 ns | 5529491940 ns | 0.97 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 6062348395 ns | 5207076241 ns | 1.16 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 7286146462 ns | 7423222920 ns | 0.98 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux | 34742615164 ns | 29111245934 ns | 1.19 |
| ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1198386252 ns | 1210144232 ns | 0.99 |
| ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant | 1183740364 ns | 1203258421 ns | 0.98 |
| ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1226242113 ns | 1201314375 ns | 1.02 |
| ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) | 2824009338 ns | 2772640660 ns | 1.02 |
| ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux | 8581758.5 ns | 8101273 ns | 1.06 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1659437123 ns | 1562542896 ns | 1.06 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1663747727 ns | 1590422736 ns | 1.05 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1575216589 ns | 1593527212 ns | 0.99 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3330508539 ns | 3448787500 ns | 0.97 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux | 2697236030 ns | 2352203752.5 ns | 1.15 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1177894801 ns | 1283921788 ns | 0.92 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant | 1231369417 ns | 1289501443 ns | 0.95 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1281022614 ns | 1236773657 ns | 1.04 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) | 2992466689 ns | 2994979934 ns | 1.00 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux | 22743533 ns | 21119477 ns | 1.08 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 2143636724 ns | 2189596614 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant | 2146473226 ns | 2167135675 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 2157303648 ns | 2152353497 ns | 1.00 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 3959878673 ns | 3988887387 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux | 5865019185 ns | 5727971331.5 ns | 1.02 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1387842900 ns | 1351251105 ns | 1.03 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant | 1256093379 ns | 1334983286 ns | 0.94 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1257804932 ns | 1326564260 ns | 0.95 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3119512525 ns | 3006970362 ns | 1.04 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux | 6970194 ns | 7156082 ns | 0.97 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1427711032 ns | 1442181095 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1416276590 ns | 1449779916 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1497474638 ns | 1440104385 ns | 1.04 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3162263527 ns | 3206233925 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux | 1224458504 ns | 1453999347 ns | 0.84 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1258743352 ns | 1278843849 ns | 0.98 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant | 1272174301 ns | 1276785330 ns | 1.00 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1286613721 ns | 1466486039 ns | 0.88 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) | 2941431002 ns | 3029520232 ns | 0.97 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux | 12308639.5 ns | 11308170.5 ns | 1.09 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 1728521529 ns | 1739952511 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant | 1726871692 ns | 1708928970 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 1724333072 ns | 1699255869 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3485347197 ns | 3445714498 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux | 3314985446 ns | 3319646969 ns | 1.00 |
| ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1296002633 ns | 1277942983 ns | 1.01 |
| ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant | 1304207360 ns | 1251667936 ns | 1.04 |
| ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1255941581 ns | 1262343075 ns | 0.99 |
| ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3046234878 ns | 3208271409 ns | 0.95 |
| ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux | 27396872 ns | 25656532 ns | 1.07 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 2236205101 ns | 2171969273 ns | 1.03 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant | 2198709936 ns | 2152154812 ns | 1.02 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 2296730805 ns | 2160330860 ns | 1.06 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3958176139 ns | 3932782743 ns | 1.01 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux | 5608465822 ns | 6171257189 ns | 0.91 |
| ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1251122132 ns | 1288285171 ns | 0.97 |
| ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant | 1342649088 ns | 1266265410 ns | 1.06 |
| ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1266853946 ns | 1277150659 ns | 0.99 |
| ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3134977417 ns | 3063914527 ns | 1.02 |
| ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux | 53144411 ns | 50680627 ns | 1.05 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 3046201375 ns | 2975571725 ns | 1.02 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant | 3063894685 ns | 2986093155 ns | 1.03 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 3087462133 ns | 2978196812 ns | 1.04 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 4895902746 ns | 4911960375 ns | 1.00 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux | 9797674342 ns | 8980759371 ns | 1.09 |
| ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1295665825 ns | 1443741989 ns | 0.90 |
| ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant | 1436872653 ns | 1518889344 ns | 0.95 |
| ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1277008406 ns | 1290151359 ns | 0.99 |
| ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3239595052 ns | 3140508919 ns | 1.03 |
| ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux | 71339137 ns | 68715424.5 ns | 1.04 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 3189935441 ns | 3162377720 ns | 1.01 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant | 3207736040 ns | 3140139532 ns | 1.02 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 3222417327 ns | 3144637121 ns | 1.02 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 5319487172 ns | 5092189137 ns | 1.04 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux | 12850280060 ns | 12329016315 ns | 1.04 |
| ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1274372415 ns | 1326072250 ns | 0.96 |
| ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant | 1287517573 ns | 1283998354 ns | 1.00 |
| ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1311075350 ns | 1283550328 ns | 1.02 |
| ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3155538600 ns | 3074369011 ns | 1.03 |
| ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux | 20891523 ns | 19742705 ns | 1.06 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1909585662 ns | 1880491780 ns | 1.02 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1907821565 ns | 1867141478 ns | 1.02 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1918208063 ns | 2022272204 ns | 0.95 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3741288084 ns | 3598201159 ns | 1.04 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux | 3617112173 ns | 4440776854 ns | 0.81 |

This comment was automatically generated by workflow using github-action-benchmark.
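
The Ratio column is not defined in the comment. Judging from the rows, it appears to be the current time divided by the previous time, rounded to two decimals, so values above 1 would mean the current commit ran slower. A small sketch of that arithmetic, where the `ratio` helper is hypothetical and the rounding rule is an assumption inferred from the table:

```julia
# Assumed definition, inferred from the rows: ratio = current / previous,
# rounded to two decimal places. `ratio` is a hypothetical helper.
ratio(current_ns, previous_ns) = round(current_ns / previous_ns; digits=2)

# ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant
println(ratio(1443950901, 1295960067))  # table lists 1.11

# ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux
println(ratio(3617112173, 4440776854))  # table lists 0.81
```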