
Internal Error (Could not find any implementation for node {...Tensordot/Reshape}) failure of TensorRT 8.6 when running generate INT8 engine on GPU RTX3080 #4291

Open
yjiangling opened this issue Dec 19, 2024 · 5 comments
Labels: Engine Build, quantization, triaged

@yjiangling

Hi,
I'm using TensorRT 8.6 to run INT8 calibration and generate a TensorRT engine with the Polygraphy tool, like this:

polygraphy convert onnx_model/model.onnx --trt-min-shapes xs:[1,1120] xlen:[1] --trt-opt-shapes xs:[1,160000] xlen:[1] --trt-max-shapes xs:[1,480000] xlen:[1] --int8 --data-loader-script data_loader.py --calibration-cache trt86_minmax_calib.cache --calib-base-cls IInt8MinMaxCalibrator --output trt_model/trt86_minmax_int8.plan

but it always gives the following error:

[W] 'colored' module is not installed, will not use colors when logging. To enable colors, please install the 'colored' module: python3 -m pip install colored
[W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[I]     Configuring with profiles: [Profile().add('xs', min=[1, 1120], opt=[1, 160000], max=[1, 480000]).add('xlen', min=[1], opt=[1], max=[1])]
[W] TensorRT does not currently support using dynamic shapes during calibration. The `OPT` shapes from the calibration profile will be used for tensors with dynamic shapes. Calibration data is expected to conform to those shapes. 
[I] Building engine with configuration:
    Flags                  | [INT8]
    Engine Capability      | EngineCapability.DEFAULT
    Memory Pools           | [WORKSPACE: 10002.44 MiB, TACTIC_DRAM: 10002.44 MiB]
    Tactic Sources         | [CUBLAS, CUBLAS_LT, CUDNN, EDGE_MASK_CONVOLUTIONS, JIT_CONVOLUTIONS]
    Profiling Verbosity    | ProfilingVerbosity.DETAILED
    Preview Features       | [FASTER_DYNAMIC_SHAPES_0805, DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
    Calibrator             | Calibrator(<generator object load_data at 0x7fd95cd5c430>, cache='trt86_minmax_calib.cache', BaseClass=<class 'tensorrt.tensorrt.IInt8MinMaxCalibrator'>)
[I] Saving calibration cache to trt86_minmax_calib.cache
[W] Missing scale and zero-point for tensor xlen, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 1) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor sub:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 9) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor floordiv:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor stft/frame/range_1:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 59) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor stft/frame/mul_1:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor stft/frame/Reshape_2:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 70) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor stft/frame/add_2:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 82) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor add:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/ExpandDims:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/Cast:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 107) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor (Unnamed Layer* 112) [Constant]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/Range:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
[W] Missing scale and zero-point for tensor SequenceMask/Less:0, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
formats.cpp:2379: DCHECK(desired_so.size() == t->dim_count()) failed. 
[E] 10: Could not find any implementation for node {ForeignNode[(Unnamed Layer* 82) [Shuffle]_output[Constant]...Tensordot/Reshape]}.
[E] 10: [optimizer.cpp::computeCosts::3869] Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[(Unnamed Layer* 82) [Shuffle]_output[Constant]...Tensordot/Reshape]}.)
[!] Invalid Engine. Please ensure the engine was built correctly

The converted ONNX model was generated from a SavedModel, which was produced with the code below:

import librosa
import tensorflow as tf

xs = tf.placeholder(tf.float32, shape=[None, None], name='xs')
xlen = tf.placeholder(tf.int32, shape=[None], name='xlen')
spectrogram = tf.square(tf.abs(tf.signal.stft(xs, frame_length=400, frame_step=160, fft_length=512)))
weight = librosa.filters.mel(sr=16000, n_fft=512, n_mels=80).T  # (257,80)
weight = tf.convert_to_tensor(weight, tf.float32)
mel = tf.tensordot(spectrogram, weight, axes=1)
mel_length = (xlen - 400) // 160 + 1
mask = tf.sequence_mask(mel_length, maxlen=tf.shape(mel)[1], dtype=tf.float32)
mel = mel * tf.expand_dims(mask, axis=-1)
ys = tf.identity(mel, name='ys')

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())

	tf.saved_model.simple_save(sess,
		'./saved_model',
		inputs={"xs": xs, "xlen": xlen},
		outputs={"ys": ys})

Then the SavedModel is converted to an ONNX model with python3 -m tf2onnx.convert --opset 13 --saved-model ./saved_model/ --output ./onnx_model/model.onnx. What's wrong with it?

By the way, the file data_loader.py is used for INT8 calibration; you can reproduce it like this:

import numpy as np

def load_data(calib_num=100):
	for _ in range(calib_num):
		n = np.random.randint(1000, 5000)
		x = np.random.rand(1, n).astype(np.float32)
		x_len = np.array([n], dtype=np.int32)
		yield {"xs": x, "xlen": x_len}
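A quick sanity check on this loader (a sketch of my own, reusing the load_data generator above; the assertions are not part of the original script) confirms that each feed dict carries the expected input names, dtypes, and consistent lengths:

```python
import numpy as np

def load_data(calib_num=100):
    # Same generator as above: random audio of random length n.
    for _ in range(calib_num):
        n = np.random.randint(1000, 5000)
        x = np.random.rand(1, n).astype(np.float32)
        x_len = np.array([n], dtype=np.int32)
        yield {"xs": x, "xlen": x_len}

# Check the first few batches: names, dtypes, and matching lengths.
for feed in load_data(calib_num=3):
    assert set(feed) == {"xs", "xlen"}
    assert feed["xs"].dtype == np.float32 and feed["xs"].ndim == 2
    assert feed["xlen"].dtype == np.int32
    assert feed["xlen"][0] == feed["xs"].shape[1]
```

Note the warning in the log above: with dynamic shapes, TensorRT 8.6 calibrates at the OPT shapes (xs:[1,160000]), and calibration data is expected to conform to those shapes, so the random lengths here (1000..4999) may be a mismatch.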

Can anyone offer some help? Thanks a lot!

@yjiangling
Author

model.zip

Here is the ONNX model file; you can use it directly, or generate a new model with the code above.

@yjiangling yjiangling changed the title Internal Error (Could not find any implementation for node {...Tensordot/Reshape}) failure of TensorRT 8.6 when running generate INT8 egine on GPU RTX3080 Internal Error (Could not find any implementation for node {...Tensordot/Reshape}) failure of TensorRT 8.6 when running generate INT8 engine on GPU RTX3080 Dec 19, 2024
@asfiyab-nvidia asfiyab-nvidia added Engine Build Issues with engine build triaged Issue has been triaged by maintainers quantization Issues related to Quantization labels Dec 19, 2024
@yuanyao-nv
Collaborator

@yjiangling I see the TRT version you used is quite old. Would you like to try the latest 10.7 to check if the error still occurs?

@yjiangling
Author

yjiangling commented Dec 20, 2024

@yuanyao-nv Yes, I'll try installing the latest Docker image and run it again, but for some reasons we must run our models on TensorRT 8.6 for now, so could you please help investigate what's wrong? In my experiments I found that if I remove the line mel = mel * tf.expand_dims(mask, axis=-1), the TensorRT engine is generated successfully without any error. It seems there is something subtle in TensorRT 8.6... I would greatly appreciate it if you could help figure this out.
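In case it helps others hitting the same node, one workaround I considered (a sketch of my own, not a TensorRT fix) is to export the graph without that masking line and apply the sequence mask in NumPy on the engine's output instead:

```python
import numpy as np

def apply_sequence_mask(mel, xlen, frame_length=400, frame_step=160):
    """Zero out mel frames beyond each sequence's valid length.

    mel:  (batch, time, n_mels) float32, the engine's output
    xlen: (batch,) int32, the raw sample lengths fed as 'xlen'
    """
    # Same frame-count formula as in the TF graph: (xlen - 400) // 160 + 1
    mel_length = (xlen - frame_length) // frame_step + 1
    time = mel.shape[1]
    # mask[b, t] = 1.0 while t < mel_length[b], else 0.0 (like tf.sequence_mask)
    mask = (np.arange(time)[None, :] < mel_length[:, None]).astype(np.float32)
    return mel * mask[:, :, None]
```

This keeps the Tensordot/Reshape region free of the mask multiply that seems to trigger the INT8 builder failure, at the cost of one extra host-side step per batch.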

@yjiangling
Copy link
Author

@asfiyab-nvidia Could you please help? Many thanks. The partner assigned last week may be on vacation or too busy; I'm hoping to receive some suggestions from your team. Thanks again.

  1. By the way, due to hardware and software requirements in our project, we must use TensorRT 8.6 for now. What's more, if we generate the FP32 engine directly, without INT8 calibration and quantization, everything is OK, so we suspect the problem lies in the calibration and quantization stages.

  2. Even if we remove the line mel = mel * tf.expand_dims(mask, axis=-1) and use polygraphy convert to generate the INT8-quantized model, the engine does not produce the right result; and when we used trtexec with the calibration cache file to perform calibration and quantization, it failed again:

trtexec --onnx=onnx_model/model.onnx --minShapes=xs:1x1120,xlen:1 --optShapes=xs:1x160000,xlen:1 --maxShapes=xs:1x480000,xlen:1 --minShapesCalib=xs:1x1120,xlen:1 --optShapesCalib=xs:1x160000,xlen:1 --maxShapesCalib=xs:1x480000,xlen:1 --workspace=10240 --int8 --calib=trt86_minmax_calib.cache --saveEngine=trt_model/trt86_minmax_int8.plan --verbose --buildOnly > trt_model/result-INT8.txt

So it seems there is something wrong with the calibration cache file generated by polygraphy convert earlier?

@yjiangling
Copy link
Author

@lix19937 May I have your help? Many thanks.
