[TF:TRT] Add utilities for converting between TF types and TRT types.
PiperOrigin-RevId: 312087947 Change-Id: Ie4c47ab5c6aae97af5a83bba06e3de0637752ecf
This commit is contained in:
parent
50fcac47a2
commit
55aee9e550
@ -137,30 +137,18 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
|
||||
return os;
|
||||
}
|
||||
|
||||
nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) {
|
||||
switch (tf_dtype) {
|
||||
case DT_FLOAT:
|
||||
return nvinfer1::DataType::kFLOAT;
|
||||
case DT_HALF:
|
||||
return nvinfer1::DataType::kHALF;
|
||||
case DT_INT32:
|
||||
return nvinfer1::DataType::kINT32;
|
||||
default:
|
||||
QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype);
|
||||
}
|
||||
nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) {
|
||||
nvinfer1::DataType trt_type;
|
||||
Status status = TfTypeToTrtType(tf_type, &trt_type);
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
return trt_type;
|
||||
}
|
||||
|
||||
DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
|
||||
switch (trt_dtype) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return DT_FLOAT;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return DT_HALF;
|
||||
case nvinfer1::DataType::kINT32:
|
||||
return DT_INT32;
|
||||
default:
|
||||
QCHECK(false) << "Unexpected data type " << static_cast<int>(trt_dtype);
|
||||
}
|
||||
DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) {
|
||||
DataType tf_type;
|
||||
Status status = TrtTypeToTfType(trt_type, &tf_type);
|
||||
EXPECT_EQ(status, Status::OK());
|
||||
return tf_type;
|
||||
}
|
||||
|
||||
NodeDef MakeNodeDef(const string& name, const string& op,
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
@ -185,6 +186,40 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
|
||||
switch (tf_type) {
|
||||
case DT_FLOAT:
|
||||
*trt_type = nvinfer1::DataType::kFLOAT;
|
||||
break;
|
||||
case DT_HALF:
|
||||
*trt_type = nvinfer1::DataType::kHALF;
|
||||
break;
|
||||
case DT_INT32:
|
||||
*trt_type = nvinfer1::DataType::kINT32;
|
||||
break;
|
||||
default:
|
||||
return errors::Internal("Unsupported tensorflow type");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
|
||||
switch (trt_type) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
*tf_type = DT_FLOAT;
|
||||
break;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
*tf_type = DT_HALF;
|
||||
break;
|
||||
case nvinfer1::DataType::kINT32:
|
||||
*tf_type = DT_INT32;
|
||||
break;
|
||||
default:
|
||||
return errors::Internal("Invalid TRT type");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
|
||||
int n_bindings = engine->getNbBindings();
|
||||
int n_input = 0;
|
||||
|
@ -106,6 +106,9 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
|
||||
bool use_implicit_batch, int batch_size,
|
||||
TensorShape& shape);
|
||||
|
||||
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
|
||||
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
|
||||
|
||||
// Returns a string that includes compile time TensorRT library version
|
||||
// information {Maj, Min, Patch}.
|
||||
string GetLinkedTensorRTVersion();
|
||||
|
Loading…
Reference in New Issue
Block a user