From 55aee9e55084b309d5a01dae6685d4622482d6df Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Mon, 18 May 2020 08:55:02 -0700 Subject: [PATCH] [TF:TRT] Add utilities for converting between TF types and TRT types. PiperOrigin-RevId: 312087947 Change-Id: Ie4c47ab5c6aae97af5a83bba06e3de0637752ecf --- .../tf2tensorrt/convert/convert_nodes_test.cc | 32 ++++++----------- .../compiler/tf2tensorrt/convert/utils.cc | 35 +++++++++++++++++++ .../compiler/tf2tensorrt/convert/utils.h | 3 ++ 3 files changed, 48 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 82c02c17e93..964370af6be 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -137,30 +137,18 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) { return os; } -nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) { - switch (tf_dtype) { - case DT_FLOAT: - return nvinfer1::DataType::kFLOAT; - case DT_HALF: - return nvinfer1::DataType::kHALF; - case DT_INT32: - return nvinfer1::DataType::kINT32; - default: - QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype); - } +nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) { + nvinfer1::DataType trt_type; + Status status = TfTypeToTrtType(tf_type, &trt_type); + EXPECT_EQ(status, Status::OK()); + return trt_type; } -DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) { - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - return DT_FLOAT; - case nvinfer1::DataType::kHALF: - return DT_HALF; - case nvinfer1::DataType::kINT32: - return DT_INT32; - default: - QCHECK(false) << "Unexpected data type " << static_cast(trt_dtype); - } +DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) { + DataType tf_type; + Status status = TrtTypeToTfType(trt_type, &tf_type); + EXPECT_EQ(status, Status::OK()); + return tf_type; } NodeDef MakeNodeDef(const string& name, const string& op, diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index fb3ae6943d3..a4b64ec0dc5 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace tensorrt { @@ -185,6 +186,40 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, return Status::OK(); } +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) { + switch (tf_type) { + case DT_FLOAT: + *trt_type = nvinfer1::DataType::kFLOAT; + break; + case DT_HALF: + *trt_type = nvinfer1::DataType::kHALF; + break; + case DT_INT32: + *trt_type = nvinfer1::DataType::kINT32; + break; + default: + return errors::Internal("Unsupported tensorflow type"); + } + return Status::OK(); +} + +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) { + switch (trt_type) { + case nvinfer1::DataType::kFLOAT: + *tf_type = DT_FLOAT; + break; + case nvinfer1::DataType::kHALF: + *tf_type = DT_HALF; + break; + case nvinfer1::DataType::kINT32: + *tf_type = DT_INT32; + break; + default: + return errors::Internal("Invalid TRT type"); + } + return Status::OK(); +} + int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { int n_bindings = engine->getNbBindings(); int n_input = 0; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 5d4cf1bb851..59eeb420134 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -106,6 +106,9 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, bool use_implicit_batch, int batch_size, TensorShape& shape); +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type); + // Returns a string that includes compile time TensorRT library version // information {Maj, Min, Patch}. string GetLinkedTensorRTVersion();