From a5255dc1a7c6dadddef740127866fee4b0e2fda9 Mon Sep 17 00:00:00 2001 From: JonLuca De Caro Date: Wed, 27 Mar 2024 09:22:35 -0700 Subject: [PATCH] Fix xformers to work on v0.0.22 - 0.0.25 (#136) --- src/sfast/libs/xformers/xformers_attention.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/sfast/libs/xformers/xformers_attention.py b/src/sfast/libs/xformers/xformers_attention.py index f4942d2..deb4ba9 100644 --- a/src/sfast/libs/xformers/xformers_attention.py +++ b/src/sfast/libs/xformers/xformers_attention.py @@ -4,17 +4,21 @@ from xformers import ops from sfast.utils.custom_python_operator import register_custom_python_operator -OP_STR_MAP = { - ops.MemoryEfficientAttentionCutlassFwdFlashBwOp: +OP_STR_MAP = {} + +for attr_name in [ 'MemoryEfficientAttentionCutlassFwdFlashBwOp', - ops.MemoryEfficientAttentionCutlassOp: 'MemoryEfficientAttentionCutlassOp', - ops.MemoryEfficientAttentionFlashAttentionOp: + 'MemoryEfficientAttentionCutlassOp', 'MemoryEfficientAttentionFlashAttentionOp', - ops.MemoryEfficientAttentionOp: 'MemoryEfficientAttentionOp', - ops.MemoryEfficientAttentionTritonFwdFlashBwOp: + 'MemoryEfficientAttentionOp', 'MemoryEfficientAttentionTritonFwdFlashBwOp', - ops.TritonFlashAttentionOp: 'TritonFlashAttentionOp', -} + 'TritonFlashAttentionOp', + 'MemoryEfficientAttentionCkOp', + 'MemoryEfficientAttentionSplitKCkOp' +]: + op_attr = getattr(ops, attr_name, None) + if op_attr is not None: + OP_STR_MAP[op_attr] = attr_name STR_OP_MAP = {v: k for k, v in OP_STR_MAP.items()}