Disable ISliceLayer in TRT 5.1 until the bug is fixed.
PiperOrigin-RevId: 237892535
This commit is contained in:
parent
73cdb00c26
commit
67759c1a25
@ -311,9 +311,9 @@ Status Converter::GetTrtBroadcastShape(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
|
const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
|
||||||
auto compute_output_dims =
|
auto compute_output_dims = [](const TRT_TensorOrWeights& input,
|
||||||
[](const TRT_TensorOrWeights& input, int broadcast_num_dims,
|
int broadcast_num_dims, int* output_dims_array,
|
||||||
int* output_dims_array, nvinfer1::Dims* output_dims) {
|
nvinfer1::Dims* output_dims) {
|
||||||
const nvinfer1::Dims input_dims = input.GetTrtDims();
|
const nvinfer1::Dims input_dims = input.GetTrtDims();
|
||||||
std::fill(output_dims_array, output_dims_array + max_nb_dims, 1);
|
std::fill(output_dims_array, output_dims_array + max_nb_dims, 1);
|
||||||
std::copy(input_dims.d, input_dims.d + input_dims.nbDims,
|
std::copy(input_dims.d, input_dims.d + input_dims.nbDims,
|
||||||
@ -2296,7 +2296,11 @@ Status ConvertStridedSliceHelper(OpConverterParams* params,
|
|||||||
}
|
}
|
||||||
// TRT 5.1 adds a slice layer. For older versions, we attempt to use the
|
// TRT 5.1 adds a slice layer. For older versions, we attempt to use the
|
||||||
// padding layer with negative padding.
|
// padding layer with negative padding.
|
||||||
#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1)
|
#if (NV_TENSORRT_MAJOR > 5 || \
|
||||||
|
(NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1)) && \
|
||||||
|
0
|
||||||
|
// TODO(laigd): TRT 5.1 RC has a bug when ISliceLayer is used along with
|
||||||
|
// IConcatenationLayer, so disable ISliceLayer for now until it's fixed.
|
||||||
// Use ISliceLayer.
|
// Use ISliceLayer.
|
||||||
nvinfer1::Dims begin_dims, size_dims, stride_dims;
|
nvinfer1::Dims begin_dims, size_dims, stride_dims;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims,
|
TF_RETURN_IF_ERROR(TensorShapeArrayToTrtDims(begin, &begin_dims,
|
||||||
|
Loading…
Reference in New Issue
Block a user