[TF:TRT] Add utilities for converting between TF types and TRT types.
PiperOrigin-RevId: 312087947 Change-Id: Ie4c47ab5c6aae97af5a83bba06e3de0637752ecf
This commit is contained in:
parent
50fcac47a2
commit
55aee9e550
@ -137,30 +137,18 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) {
|
nvinfer1::DataType TfDataTypeToTrt(DataType tf_type) {
|
||||||
switch (tf_dtype) {
|
nvinfer1::DataType trt_type;
|
||||||
case DT_FLOAT:
|
Status status = TfTypeToTrtType(tf_type, &trt_type);
|
||||||
return nvinfer1::DataType::kFLOAT;
|
EXPECT_EQ(status, Status::OK());
|
||||||
case DT_HALF:
|
return trt_type;
|
||||||
return nvinfer1::DataType::kHALF;
|
|
||||||
case DT_INT32:
|
|
||||||
return nvinfer1::DataType::kINT32;
|
|
||||||
default:
|
|
||||||
QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
|
DataType TrtDataTypeToTf(nvinfer1::DataType trt_type) {
|
||||||
switch (trt_dtype) {
|
DataType tf_type;
|
||||||
case nvinfer1::DataType::kFLOAT:
|
Status status = TrtTypeToTfType(trt_type, &tf_type);
|
||||||
return DT_FLOAT;
|
EXPECT_EQ(status, Status::OK());
|
||||||
case nvinfer1::DataType::kHALF:
|
return tf_type;
|
||||||
return DT_HALF;
|
|
||||||
case nvinfer1::DataType::kINT32:
|
|
||||||
return DT_INT32;
|
|
||||||
default:
|
|
||||||
QCHECK(false) << "Unexpected data type " << static_cast<int>(trt_dtype);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeDef MakeNodeDef(const string& name, const string& op,
|
NodeDef MakeNodeDef(const string& name, const string& op,
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
@ -185,6 +186,40 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
|
||||||
|
switch (tf_type) {
|
||||||
|
case DT_FLOAT:
|
||||||
|
*trt_type = nvinfer1::DataType::kFLOAT;
|
||||||
|
break;
|
||||||
|
case DT_HALF:
|
||||||
|
*trt_type = nvinfer1::DataType::kHALF;
|
||||||
|
break;
|
||||||
|
case DT_INT32:
|
||||||
|
*trt_type = nvinfer1::DataType::kINT32;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return errors::Internal("Unsupported tensorflow type");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
|
||||||
|
switch (trt_type) {
|
||||||
|
case nvinfer1::DataType::kFLOAT:
|
||||||
|
*tf_type = DT_FLOAT;
|
||||||
|
break;
|
||||||
|
case nvinfer1::DataType::kHALF:
|
||||||
|
*tf_type = DT_HALF;
|
||||||
|
break;
|
||||||
|
case nvinfer1::DataType::kINT32:
|
||||||
|
*tf_type = DT_INT32;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return errors::Internal("Invalid TRT type");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
|
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
|
||||||
int n_bindings = engine->getNbBindings();
|
int n_bindings = engine->getNbBindings();
|
||||||
int n_input = 0;
|
int n_input = 0;
|
||||||
|
@ -106,6 +106,9 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
|
|||||||
bool use_implicit_batch, int batch_size,
|
bool use_implicit_batch, int batch_size,
|
||||||
TensorShape& shape);
|
TensorShape& shape);
|
||||||
|
|
||||||
|
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
|
||||||
|
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);
|
||||||
|
|
||||||
// Returns a string that includes compile time TensorRT library version
|
// Returns a string that includes compile time TensorRT library version
|
||||||
// information {Maj, Min, Patch}.
|
// information {Maj, Min, Patch}.
|
||||||
string GetLinkedTensorRTVersion();
|
string GetLinkedTensorRTVersion();
|
||||||
|
Loading…
Reference in New Issue
Block a user