[TF:TRT] Add utilities for converting between TF types and TRT types.

PiperOrigin-RevId: 312087947
Change-Id: Ie4c47ab5c6aae97af5a83bba06e3de0637752ecf
This commit is contained in:
Bixia Zheng 2020-05-18 08:55:02 -07:00 committed by TensorFlower Gardener
parent 50fcac47a2
commit 55aee9e550
3 changed files with 48 additions and 22 deletions

View File

@ -137,30 +137,18 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
return os;
}
nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) {
switch (tf_dtype) {
case DT_FLOAT:
return nvinfer1::DataType::kFLOAT;
case DT_HALF:
return nvinfer1::DataType::kHALF;
case DT_INT32:
return nvinfer1::DataType::kINT32;
default:
QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype);
}
nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) {
nvinfer1::DataType trt_type;
Status status = TfTypeToTrtType(tf_type, &trt_type);
EXPECT_EQ(status, Status::OK());
return trt_type;
}
DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
return DT_FLOAT;
case nvinfer1::DataType::kHALF:
return DT_HALF;
case nvinfer1::DataType::kINT32:
return DT_INT32;
default:
QCHECK(false) << "Unexpected data type " << static_cast<int>(trt_dtype);
}
DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) {
DataType tf_type;
Status status = TrtTypeToTfType(trt_type, &tf_type);
EXPECT_EQ(status, Status::OK());
return tf_type;
}
NodeDef MakeNodeDef(const string& name, const string& op,

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace tensorrt {
@ -185,6 +186,40 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
return Status::OK();
}
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
switch (tf_type) {
case DT_FLOAT:
*trt_type = nvinfer1::DataType::kFLOAT;
break;
case DT_HALF:
*trt_type = nvinfer1::DataType::kHALF;
break;
case DT_INT32:
*trt_type = nvinfer1::DataType::kINT32;
break;
default:
return errors::Internal("Unsupported tensorflow type");
}
return Status::OK();
}
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
switch (trt_type) {
case nvinfer1::DataType::kFLOAT:
*tf_type = DT_FLOAT;
break;
case nvinfer1::DataType::kHALF:
*tf_type = DT_HALF;
break;
case nvinfer1::DataType::kINT32:
*tf_type = DT_INT32;
break;
default:
return errors::Internal("Invalid TRT type");
}
return Status::OK();
}
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
int n_bindings = engine->getNbBindings();
int n_input = 0;

View File

@ -106,6 +106,9 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
bool use_implicit_batch, int batch_size,
TensorShape& shape);
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
// Returns a string that includes compile time TensorRT library version
// information {Maj, Min, Patch}.
string GetLinkedTensorRTVersion();