PyO3: Add mkl support (#1159)
* Add `mkl` support

* Set `mkl` path on Linux
LLukas22 authored Oct 23, 2023
1 parent 86e1803 commit eae94a4
Showing 3 changed files with 41 additions and 12 deletions.
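
The fallback the diffs below add to candle-pyo3/py_src/candle/__init__.py is easiest to see end to end. The following is a distilled sketch, not code from this commit: when the bundled native module fails to import, well-known environment variables (`CUDA_PATH`, `ONEAPI_ROOT`) are resolved to platform-specific library directories and registered on the DLL search path before the import is retried. The helper name `add_runtime_dir` is illustrative only.

import logging
import os
import platform


def add_runtime_dir(env_var, windows_parts, linux_parts):
    # Resolve an environment variable to a platform-specific library directory,
    # mirroring locate_cuda_dlls() / locate_mkl_dlls() in the diff below.
    root = os.environ.get(env_var)
    if not root:
        logging.warning(f"{env_var} environment variable not found!")
        return
    parts = windows_parts if platform.system() == "Windows" else linux_parts
    path = os.path.join(root, *parts)
    logging.warning(f"Adding {path} to DLL search path...")
    if hasattr(os, "add_dll_directory"):  # this API only exists on Windows (Python 3.8+)
        os.add_dll_directory(path)


# CUDA runtime, as in locate_cuda_dlls()
add_runtime_dir("CUDA_PATH", ["bin"], ["lib64"])
# MKL runtime shipped with the oneAPI toolkit, as in locate_mkl_dlls()
add_runtime_dir(
    "ONEAPI_ROOT",
    ["compiler", "latest", "windows", "redist", "intel64_win", "compiler"],
    ["mkl", "latest", "lib", "intel64"],
)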
2 changes: 2 additions & 0 deletions candle-pyo3/Cargo.toml
@@ -18,10 +18,12 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 candle-nn = { path = "../candle-nn", version = "0.3.0" }
 half = { workspace = true }
 pyo3 = { version = "0.19.0", features = ["extension-module"] }
+intel-mkl-src = { workspace = true, optional = true }
 
 [build-dependencies]
 pyo3-build-config = "0.19"
 
 [features]
 default = []
 cuda = ["candle/cuda"]
+mkl = ["dep:intel-mkl-src","candle/mkl"]
48 changes: 36 additions & 12 deletions candle-pyo3/py_src/candle/__init__.py
@@ -3,27 +3,51 @@
 try:
     from .candle import *
 except ImportError as e:
-    # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
-    logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
+    # If we are in development mode, or we did not bundle the DLLs, we try to locate them here
+    # PyO3 won't give us any information about which DLLs are missing, so we can only try to load the DLLs and re-import the module
+    logging.warning("DLLs were not bundled with this package. Trying to locate them...")
     import os
     import platform
 
-    # Try to locate CUDA_PATH environment variable
-    cuda_path = os.environ.get("CUDA_PATH", None)
-    if cuda_path:
-        logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
-        if platform.system() == "Windows":
-            cuda_path = os.path.join(cuda_path, "bin")
+    def locate_cuda_dlls():
+        logging.warning("Locating CUDA DLLs...")
+        # Try to locate CUDA_PATH environment variable
+        cuda_path = os.environ.get("CUDA_PATH", None)
+        if cuda_path:
+            logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
+            if platform.system() == "Windows":
+                cuda_path = os.path.join(cuda_path, "bin")
+            else:
+                cuda_path = os.path.join(cuda_path, "lib64")
+
+            logging.warning(f"Adding {cuda_path} to DLL search path...")
+            os.add_dll_directory(cuda_path)
+        else:
+            logging.warning("CUDA_PATH environment variable not found!")
+
+    def locate_mkl_dlls():
+        # Try to locate ONEAPI_ROOT environment variable
+        oneapi_root = os.environ.get("ONEAPI_ROOT", None)
+        if oneapi_root:
+            if platform.system() == "Windows":
+                mkl_path = os.path.join(
+                    oneapi_root, "compiler", "latest", "windows", "redist", "intel64_win", "compiler"
+                )
+            else:
+                mkl_path = os.path.join(oneapi_root, "mkl", "latest", "lib", "intel64")
+
+            logging.warning(f"Adding {mkl_path} to DLL search path...")
+            os.add_dll_directory(mkl_path)
         else:
-            cuda_path = os.path.join(cuda_path, "lib64")
+            logging.warning("ONEAPI_ROOT environment variable not found!")
 
-        logging.warning(f"Adding {cuda_path} to DLL search path...")
-        os.add_dll_directory(cuda_path)
+    locate_cuda_dlls()
+    locate_mkl_dlls()
 
     try:
         from .candle import *
     except ImportError as inner_e:
-        raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")
+        raise ImportError("Could not locate DLLs. Please check the documentation for more information.")
 
 __doc__ = candle.__doc__
 if hasattr(candle, "__all__"):
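As a usage sketch (not part of the commit): on a machine where the wheel does not bundle the MKL runtime, pointing `ONEAPI_ROOT` at the oneAPI installation before the first import lets the fallback above register the MKL directory and retry the native import. The paths below are assumed default oneAPI install locations, not something this commit prescribes.

import os

# Assumed default oneAPI install locations; adjust to the local installation.
os.environ.setdefault(
    "ONEAPI_ROOT",
    r"C:\Program Files (x86)\Intel\oneAPI" if os.name == "nt" else "/opt/intel/oneapi",
)

import candle  # the except-ImportError fallback above runs only if the bundled import fails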
3 changes: 3 additions & 0 deletions candle-pyo3/src/lib.rs
@@ -8,6 +8,9 @@ use std::sync::Arc;
 
 use half::{bf16, f16};
 
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
 
 pub fn wrap_err(err: ::candle::Error) -> PyErr {
