Move TRT utility functions from convert_nodes to convert/utils
This commit is contained in:
parent
aff962d3b0
commit
2cf2298880
@ -200,18 +200,6 @@ int64 TFAttrs::get<int64>(const string& key) const {
|
|||||||
return this->at(key)->i();
|
return this->at(key)->i();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorShapeType>
|
|
||||||
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
|
||||||
bool ignore_first_dim) {
|
|
||||||
nvinfer1::Dims trt_dims;
|
|
||||||
const int offset = (ignore_first_dim ? 1 : 0);
|
|
||||||
for (int i = offset; i < shape.dims(); i++) {
|
|
||||||
trt_dims.d[i - offset] = shape.dim_size(i);
|
|
||||||
}
|
|
||||||
trt_dims.nbDims = shape.dims() - offset;
|
|
||||||
return trt_dims;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Container>
|
template <typename Container>
|
||||||
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
|
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
|
||||||
bool ignore_first_dim = false) {
|
bool ignore_first_dim = false) {
|
||||||
@ -314,21 +302,6 @@ Status ValidateTensorProperties(const string& producer_node_type,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugString(const nvinfer1::DimensionType type) {
|
|
||||||
switch (type) {
|
|
||||||
case nvinfer1::DimensionType::kSPATIAL:
|
|
||||||
return "kSPATIAL";
|
|
||||||
case nvinfer1::DimensionType::kCHANNEL:
|
|
||||||
return "kCHANNEL";
|
|
||||||
case nvinfer1::DimensionType::kINDEX:
|
|
||||||
return "kINDEX";
|
|
||||||
case nvinfer1::DimensionType::kSEQUENCE:
|
|
||||||
return "kSEQUENCE";
|
|
||||||
default:
|
|
||||||
return StrCat(static_cast<int>(type), "=unknown");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString(const nvinfer1::DataType trt_dtype) {
|
string DebugString(const nvinfer1::DataType trt_dtype) {
|
||||||
switch (trt_dtype) {
|
switch (trt_dtype) {
|
||||||
case nvinfer1::DataType::kFLOAT:
|
case nvinfer1::DataType::kFLOAT:
|
||||||
@ -344,20 +317,6 @@ string DebugString(const nvinfer1::DataType trt_dtype) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugString(const nvinfer1::Dims& dims) {
|
|
||||||
string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
|
|
||||||
for (int i = 0; i < dims.nbDims; ++i) {
|
|
||||||
StrAppend(&out, dims.d[i]);
|
|
||||||
if (VLOG_IS_ON(2)) {
|
|
||||||
StrAppend(&out, "[", DebugString(dims.type[i]), "],");
|
|
||||||
} else {
|
|
||||||
StrAppend(&out, ",");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
StrAppend(&out, ")");
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString(const nvinfer1::Permutation& permutation, int len) {
|
string DebugString(const nvinfer1::Permutation& permutation, int len) {
|
||||||
string out = "nvinfer1::Permutation(";
|
string out = "nvinfer1::Permutation(";
|
||||||
for (int i = 0; i < len; ++i) {
|
for (int i = 0; i < len; ++i) {
|
||||||
@ -581,14 +540,6 @@ inline nvinfer1::Dims GetTrtDimsForTensor(const Tensor& tensor) {
|
|||||||
return dims;
|
return dims;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
|
|
||||||
if (dims.nbDims < 0) return false;
|
|
||||||
for (int d = 0; d < dims.nbDims; ++d) {
|
|
||||||
if (dims.d[d] < 0) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t Prod(const nvinfer1::Dims& dims) {
|
int64_t Prod(const nvinfer1::Dims& dims) {
|
||||||
int64_t count = 1;
|
int64_t count = 1;
|
||||||
for (int d = 0; d < dims.nbDims; ++d) {
|
for (int d = 0; d < dims.nbDims; ++d) {
|
||||||
|
@ -42,14 +42,6 @@ namespace tensorrt {
|
|||||||
namespace convert {
|
namespace convert {
|
||||||
using ::stream_executor::port::StatusOr;
|
using ::stream_executor::port::StatusOr;
|
||||||
|
|
||||||
#define IS_TRT_VERSION_GE(major, minor, patch, build) \
|
|
||||||
((NV_TENSORRT_MAJOR > major) || \
|
|
||||||
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
|
|
||||||
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
|
||||||
NV_TENSORRT_PATCH > patch) || \
|
|
||||||
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
|
||||||
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
|
|
||||||
|
|
||||||
struct EngineConnection {
|
struct EngineConnection {
|
||||||
// Constructs a non-control edge.
|
// Constructs a non-control edge.
|
||||||
EngineConnection(const string& outside, int out_id, int out_port,
|
EngineConnection(const string& outside, int out_id, int out_port,
|
||||||
|
@ -17,9 +17,15 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
|
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
#include "third_party/tensorrt/NvInfer.h"
|
||||||
|
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
|
|
||||||
@ -45,6 +51,41 @@ Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name);
|
|||||||
|
|
||||||
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
|
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
|
||||||
|
#define IS_TRT_VERSION_GE(major, minor, patch, build) \
|
||||||
|
((NV_TENSORRT_MAJOR > major) || \
|
||||||
|
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
|
||||||
|
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
||||||
|
NV_TENSORRT_PATCH > patch) || \
|
||||||
|
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
||||||
|
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::DimensionType type);
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::Dims& dims);
|
||||||
|
|
||||||
|
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
|
||||||
|
if (dims.nbDims < 0) return false;
|
||||||
|
for (int d = 0; d < dims.nbDims; ++d) {
|
||||||
|
if (dims.d[d] < 0) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TensorShapeType>
|
||||||
|
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
||||||
|
bool ignore_first_dim) {
|
||||||
|
nvinfer1::Dims trt_dims;
|
||||||
|
const int offset = (ignore_first_dim ? 1 : 0);
|
||||||
|
for (int i = offset; i < shape.dims(); i++) {
|
||||||
|
trt_dims.d[i - offset] = shape.dim_size(i);
|
||||||
|
}
|
||||||
|
trt_dims.nbDims = shape.dims() - offset;
|
||||||
|
return trt_dims;
|
||||||
|
}
|
||||||
|
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
|
||||||
} // namespace tensorrt
|
} // namespace tensorrt
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user