diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 61c478b03c4f..5b7384c98f18 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -56,6 +56,8 @@
 _REQUIRED_XLA_VERSION = "2.2"
 _REQUIRED_XFORMERS_VERSION = "0.0.29"
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
 _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -67,8 +69,21 @@
 
 
 if _CAN_USE_FLASH_ATTN:
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    try:
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    except (ImportError, OSError, RuntimeError) as e:
+        # Handle ABI mismatch or other import failures gracefully.
+        # This can happen when flash_attn was compiled against a different PyTorch version.
+        logger.warning(
+            f"flash_attn is installed but failed to import: {e}. "
+            f"Falling back to native PyTorch attention."
+        )
+        _CAN_USE_FLASH_ATTN = False
+        flash_attn_func = None
+        flash_attn_varlen_func = None
+        _wrapped_flash_attn_backward = None
+        _wrapped_flash_attn_forward = None
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None
@@ -77,26 +92,47 @@
 if _CAN_USE_FLASH_ATTN_3:
-    from flash_attn_interface import flash_attn_func as flash_attn_3_func
-    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    try:
+        from flash_attn_interface import flash_attn_func as flash_attn_3_func
+        from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_FLASH_ATTN_3 = False
+        flash_attn_3_func = None
+        flash_attn_3_varlen_func = None
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
 
 if _CAN_USE_AITER_ATTN:
-    from aiter import flash_attn_func as aiter_flash_attn_func
+    try:
+        from aiter import flash_attn_func as aiter_flash_attn_func
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_AITER_ATTN = False
+        aiter_flash_attn_func = None
 else:
     aiter_flash_attn_func = None
 
 
 if _CAN_USE_SAGE_ATTN:
-    from sageattention import (
-        sageattn,
-        sageattn_qk_int8_pv_fp8_cuda,
-        sageattn_qk_int8_pv_fp8_cuda_sm90,
-        sageattn_qk_int8_pv_fp16_cuda,
-        sageattn_qk_int8_pv_fp16_triton,
-        sageattn_varlen,
-    )
+    try:
+        from sageattention import (
+            sageattn,
+            sageattn_qk_int8_pv_fp8_cuda,
+            sageattn_qk_int8_pv_fp8_cuda_sm90,
+            sageattn_qk_int8_pv_fp16_cuda,
+            sageattn_qk_int8_pv_fp16_triton,
+            sageattn_varlen,
+        )
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_SAGE_ATTN = False
+        sageattn = None
+        sageattn_qk_int8_pv_fp8_cuda = None
+        sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+        sageattn_qk_int8_pv_fp16_cuda = None
+        sageattn_qk_int8_pv_fp16_triton = None
+        sageattn_varlen = None
 else:
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
@@ -107,26 +143,48 @@
 
 
 if _CAN_USE_FLEX_ATTN:
-    # We cannot import the flex_attention function from the package directly because it is expected (from the
-    # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
-    # compiled function.
-    import torch.nn.attention.flex_attention as flex_attention
+    try:
+        # We cannot import the flex_attention function from the package directly because it is expected (from the
+        # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+        # compiled function.
+        import torch.nn.attention.flex_attention as flex_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_FLEX_ATTN = False
+        flex_attention = None
+else:
+    flex_attention = None
 
 
 if _CAN_USE_NPU_ATTN:
-    from torch_npu import npu_fusion_attention
+    try:
+        from torch_npu import npu_fusion_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_NPU_ATTN = False
+        npu_fusion_attention = None
 else:
     npu_fusion_attention = None
 
 
 if _CAN_USE_XLA_ATTN:
-    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    try:
+        from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_XLA_ATTN = False
+        xla_flash_attention = None
 else:
     xla_flash_attention = None
 
 
 if _CAN_USE_XFORMERS_ATTN:
-    import xformers.ops as xops
+    try:
+        import xformers.ops as xops
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_XFORMERS_ATTN = False
+        xops = None
 else:
     xops = None
 
@@ -152,8 +210,6 @@ def wrap(func):
     _register_fake = register_fake_no_op
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
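
The change applies the same guarded-import pattern to each optional backend: check availability, attempt the import inside try/except, and on failure log a warning, flip the corresponding _CAN_USE_* flag to False, and null out the imported symbols so later dispatch code falls back to native attention. Below is a minimal standalone sketch of that pattern, outside the diff; the module name some_optional_backend and the symbol fast_attention are hypothetical placeholders, not diffusers or backend APIs.

    import importlib.util
    import logging

    logger = logging.getLogger(__name__)

    # Probe whether the optional package is installed at all.
    _CAN_USE_BACKEND = importlib.util.find_spec("some_optional_backend") is not None

    if _CAN_USE_BACKEND:
        try:
            # Installed, but may still fail to load (e.g. compiled against a
            # different PyTorch/CUDA build, producing an ABI mismatch).
            from some_optional_backend import fast_attention
        except (ImportError, OSError, RuntimeError) as e:
            # Disable the backend instead of crashing at import time.
            logger.warning("some_optional_backend failed to import: %s. Falling back.", e)
            _CAN_USE_BACKEND = False
            fast_attention = None
    else:
        fast_attention = None

Dispatch code can then branch on _CAN_USE_BACKEND (or on fast_attention being None) and route to the native attention path when the backend is unusable.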