From f3a4f3db768d46defc16de48208107db1b32159d Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Wed, 8 Nov 2023 06:37:50 +0100 Subject: [PATCH] PyO3: Add optional `candle.onnx` module (#1282) * Start onnx integration * Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx * Implement ONNXModel * `fmt` * add `onnx` flag to python ci * Pin `protoc` to `25.0` * Setup `protoc` in wheel builds * Build wheels with `onnx` * Install `protoc` in manylinux containers * `apt` -> `yum` * Download `protoc` via bash script * Back to `manylinux: auto` * Disable `onnx` builds for linux --- .github/workflows/maturin.yml | Bin 5304 -> 6672 bytes .github/workflows/python.yml | 8 +- candle-onnx/src/eval.rs | 2 +- candle-onnx/src/lib.rs | 2 +- candle-pyo3/Cargo.toml | 3 + candle-pyo3/py_src/candle/onnx/__init__.py | 5 + candle-pyo3/py_src/candle/onnx/__init__.pyi | 89 ++++++++ candle-pyo3/src/lib.rs | 22 +- candle-pyo3/src/onnx.rs | 212 ++++++++++++++++++++ candle-pyo3/src/utils.rs | 6 + 10 files changed, 343 insertions(+), 6 deletions(-) create mode 100644 candle-pyo3/py_src/candle/onnx/__init__.py create mode 100644 candle-pyo3/py_src/candle/onnx/__init__.pyi create mode 100644 candle-pyo3/src/onnx.rs create mode 100644 candle-pyo3/src/utils.rs diff --git a/.github/workflows/maturin.yml b/.github/workflows/maturin.yml index 1413f01475fb1a0baea85c10e3c05c82e19b415d..46bdb903da63c434e0e188a438f8a6b6e8478498 100644 GIT binary patch delta 1029 zcmd6jy-EW?6ot>MOA3uf0%Bp2MTHwSXzrET-MM;8%r3*AS|veY-1C< z+-oRfA35H4P~_dA%0dzuEFg_Jp1PhqRV(!DQJbT>j662@Qs_8r*r=n13Id$L=joJ5 zEp~32I7d>xiK|4Nmx2f`}ZLr?dji+(JW$|6{N&m$~ewc}Z z`jtnwttUCwKEqvojUzX$A-CYvIOt&`a*c`4o{vU8XYY_TIE7C9$oK{=koc^x;;=1t;pOaSE`8VUda diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index bf85f5e512..be9b917ec2 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -39,6 +39,12 @@ jobs: path: ~/.cargo/registry key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }} + - name: Install Protoc + uses: arduino/setup-protoc@v2 + with: + version: "25.0" + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Install working-directory: ./candle-pyo3 run: | @@ -46,7 +52,7 @@ jobs: source .env/bin/activate pip install -U pip pip install pytest maturin black - python -m maturin develop -r + python -m maturin develop -r --features onnx - name: Check style working-directory: ./candle-pyo3 diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 51e2aa0c73..b7e325e134 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -98,7 +98,7 @@ fn get_attr_opt<'a, T: Attr + ?Sized>( } } -fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result { +pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result { let dims: Vec = t.dims.iter().map(|&x| x as usize).collect(); match DataType::try_from(t.data_type) { Ok(DataType::Int32) => { diff --git a/candle-onnx/src/lib.rs b/candle-onnx/src/lib.rs index 1002a2c868..efd6f7600f 100644 --- a/candle-onnx/src/lib.rs +++ b/candle-onnx/src/lib.rs @@ -5,7 +5,7 @@ pub mod onnx { include!(concat!(env!("OUT_DIR"), "/onnx.rs")); } -mod eval; +pub mod eval; pub use eval::{dtype, simple_eval}; pub fn read_file>(p: P) -> Result { diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index b04524040e..f79277f291 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -17,6 +17,7 @@ crate-type = ["cdylib"] accelerate-src = { workspace = true, optional = true } candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.3.0" } +candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true} half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] } @@ -29,3 +30,5 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate"] cuda = ["candle/cuda"] mkl = ["dep:intel-mkl-src","candle/mkl"] +onnx = ["dep:candle-onnx"] + diff --git a/candle-pyo3/py_src/candle/onnx/__init__.py b/candle-pyo3/py_src/candle/onnx/__init__.py new file mode 100644 index 0000000000..856ecd7d97 --- /dev/null +++ b/candle-pyo3/py_src/candle/onnx/__init__.py @@ -0,0 +1,5 @@ +# Generated content DO NOT EDIT +from .. import onnx + +ONNXModel = onnx.ONNXModel +ONNXTensorDescription = onnx.ONNXTensorDescription diff --git a/candle-pyo3/py_src/candle/onnx/__init__.pyi b/candle-pyo3/py_src/candle/onnx/__init__.pyi new file mode 100644 index 0000000000..8ce1b3aaca --- /dev/null +++ b/candle-pyo3/py_src/candle/onnx/__init__.pyi @@ -0,0 +1,89 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device, Scalar, Index, Shape +from candle import Tensor, DType, QTensor + +class ONNXModel: + """ + A wrapper around an ONNX model. + """ + + def __init__(self, path: str): + pass + @property + def doc_string(self) -> str: + """ + The doc string of the model. + """ + pass + @property + def domain(self) -> str: + """ + The domain of the operator set of the model. + """ + pass + def initializers(self) -> Dict[str, Tensor]: + """ + Get the weights of the model. + """ + pass + @property + def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: + """ + The inputs of the model. + """ + pass + @property + def ir_version(self) -> int: + """ + The version of the IR this model targets. + """ + pass + @property + def model_version(self) -> int: + """ + The version of the model. + """ + pass + @property + def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: + """ + The outputs of the model. + """ + pass + @property + def producer_name(self) -> str: + """ + The producer of the model. + """ + pass + @property + def producer_version(self) -> str: + """ + The version of the producer of the model. + """ + pass + def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Run the model on the given inputs. + """ + pass + +class ONNXTensorDescription: + """ + A wrapper around an ONNX tensor description. + """ + + @property + def dtype(self) -> DType: + """ + The data type of the tensor. + """ + pass + @property + def shape(self) -> Tuple[Union[int, str, Any]]: + """ + The shape of the tensor. + """ + pass diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ddd58fbe16..05a786efa0 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -19,12 +19,14 @@ extern crate accelerate_src; use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +mod utils; +use utils::wrap_err; + mod shape; use shape::{PyShape, PyShapeWithHole}; -pub fn wrap_err(err: ::candle::Error) -> PyErr { - PyErr::new::(format!("{err:?}")) -} +#[cfg(feature = "onnx")] +mod onnx; #[derive(Clone, Debug)] #[pyclass(name = "Tensor")] @@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { Ok(()) } +#[cfg(feature = "onnx")] +fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + use onnx::{PyONNXModel, PyONNXTensorDescriptor}; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} + #[pymodule] fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let utils = PyModule::new(py, "utils")?; @@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { let nn = PyModule::new(py, "functional")?; candle_functional_m(py, nn)?; m.add_submodule(nn)?; + #[cfg(feature = "onnx")] + { + let onnx = PyModule::new(py, "onnx")?; + candle_onnx_m(py, onnx)?; + m.add_submodule(onnx)?; + } m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/candle-pyo3/src/onnx.rs b/candle-pyo3/src/onnx.rs new file mode 100644 index 0000000000..b9a0eb2272 --- /dev/null +++ b/candle-pyo3/src/onnx.rs @@ -0,0 +1,212 @@ +use std::collections::HashMap; + +use crate::utils::wrap_err; +use crate::{PyDType, PyTensor}; +use candle_onnx::eval::{dtype, get_tensor, simple_eval}; +use candle_onnx::onnx::tensor_proto::DataType; +use candle_onnx::onnx::tensor_shape_proto::dimension::Value; +use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue}; +use candle_onnx::onnx::{ModelProto, ValueInfoProto}; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::{PyList, PyTuple}; + +#[derive(Clone, Debug)] +#[pyclass(name = "ONNXTensorDescription")] +/// A wrapper around an ONNX tensor description. +pub struct PyONNXTensorDescriptor(ONNXTensor); + +#[pymethods] +impl PyONNXTensorDescriptor { + #[getter] + /// The data type of the tensor. + /// &RETURNS&: DType + fn dtype(&self) -> PyResult { + match DataType::try_from(self.0.elem_type) { + Ok(dt) => match dtype(dt) { + Some(dt) => Ok(PyDType(dt)), + None => Err(PyValueError::new_err(format!( + "unsupported 'value' data-type {dt:?}" + ))), + }, + type_ => Err(PyValueError::new_err(format!( + "unsupported input type {type_:?}" + ))), + } + } + + #[getter] + /// The shape of the tensor. + /// &RETURNS&: Tuple[Union[int,str,Any]] + fn shape(&self, py: Python) -> PyResult> { + let shape = PyList::empty(py); + if let Some(d) = &self.0.shape { + for dim in d.dim.iter() { + if let Some(value) = &dim.value { + match value { + Value::DimValue(v) => shape.append(*v)?, + Value::DimParam(s) => shape.append(s.clone())?, + }; + } else { + return Err(PyValueError::new_err("None value in shape")); + } + } + } + Ok(shape.to_tuple().into()) + } + + fn __repr__(&self, py: Python) -> String { + match (self.shape(py), self.dtype()) { + (Ok(shape), Ok(dtype)) => format!( + "TensorDescriptor[shape: {:?}, dtype: {:?}]", + shape.to_string(), + dtype.__str__() + ), + (Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(), + (Err(_), Ok(dtype)) => format!( + "TensorDescriptor[shape: unknown, dtype: {:?}]", + dtype.__str__() + ), + (Ok(shape), Err(_)) => format!( + "TensorDescriptor[shape: {:?}, dtype: unknown]", + shape.to_string() + ), + } + } + + fn __str__(&self, py: Python) -> String { + self.__repr__(py) + } +} + +#[derive(Clone, Debug)] +#[pyclass(name = "ONNXModel")] +/// A wrapper around an ONNX model. +pub struct PyONNXModel(ModelProto); + +fn extract_tensor_descriptions( + value_infos: &[ValueInfoProto], +) -> HashMap { + let mut map = HashMap::new(); + for value_info in value_infos.iter() { + let input_type = match &value_info.r#type { + Some(input_type) => input_type, + None => continue, + }; + let input_type = match &input_type.value { + Some(input_type) => input_type, + None => continue, + }; + + let tensor_type: &ONNXTensor = match input_type { + ONNXValue::TensorType(tt) => tt, + _ => continue, + }; + map.insert( + value_info.name.to_string(), + PyONNXTensorDescriptor(tensor_type.clone()), + ); + } + map +} + +#[pymethods] +impl PyONNXModel { + #[new] + #[pyo3(text_signature = "(self, path:str)")] + /// Load an ONNX model from the given path. + fn new(path: String) -> PyResult { + let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?; + Ok(PyONNXModel(model)) + } + + #[getter] + /// The version of the IR this model targets. + /// &RETURNS&: int + fn ir_version(&self) -> i64 { + self.0.ir_version + } + + #[getter] + /// The producer of the model. + /// &RETURNS&: str + fn producer_name(&self) -> String { + self.0.producer_name.clone() + } + + #[getter] + /// The version of the producer of the model. + /// &RETURNS&: str + fn producer_version(&self) -> String { + self.0.producer_version.clone() + } + + #[getter] + /// The domain of the operator set of the model. + /// &RETURNS&: str + fn domain(&self) -> String { + self.0.domain.clone() + } + + #[getter] + /// The version of the model. + /// &RETURNS&: int + fn model_version(&self) -> i64 { + self.0.model_version + } + + #[getter] + /// The doc string of the model. + /// &RETURNS&: str + fn doc_string(&self) -> String { + self.0.doc_string.clone() + } + + /// Get the weights of the model. + /// &RETURNS&: Dict[str, Tensor] + fn initializers(&self) -> PyResult> { + let mut map = HashMap::new(); + if let Some(graph) = self.0.graph.as_ref() { + for tensor_description in graph.initializer.iter() { + let tensor = get_tensor(tensor_description, tensor_description.name.as_str()) + .map_err(wrap_err)?; + map.insert(tensor_description.name.to_string(), PyTensor(tensor)); + } + } + Ok(map) + } + + #[getter] + /// The inputs of the model. + /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]] + fn inputs(&self) -> Option> { + if let Some(graph) = self.0.graph.as_ref() { + return Some(extract_tensor_descriptions(&graph.input)); + } + None + } + + #[getter] + /// The outputs of the model. + /// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]] + fn outputs(&self) -> Option> { + if let Some(graph) = self.0.graph.as_ref() { + return Some(extract_tensor_descriptions(&graph.output)); + } + None + } + + #[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")] + /// Run the model on the given inputs. + /// &RETURNS&: Dict[str,Tensor] + fn run(&self, inputs: HashMap) -> PyResult> { + let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect(); + + let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?; + + Ok(result + .into_iter() + .map(|(k, v)| (k.clone(), PyTensor(v))) + .collect()) + } +} diff --git a/candle-pyo3/src/utils.rs b/candle-pyo3/src/utils.rs new file mode 100644 index 0000000000..ad0a76a58f --- /dev/null +++ b/candle-pyo3/src/utils.rs @@ -0,0 +1,6 @@ +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +pub fn wrap_err(err: ::candle::Error) -> PyErr { + PyErr::new::(format!("{err:?}")) +}