Merge pull request #35233 from tfeher:trt_move_util_func
PiperOrigin-RevId: 286440099 Change-Id: I2c2ee35b714fb3b2340fe15907f744196e9d3744
This commit is contained in:
commit
988b0eea45
@ -500,7 +500,8 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib_proto_parsing",
|
"//tensorflow/core:lib_proto_parsing",
|
||||||
],
|
"//tensorflow/core:lib",
|
||||||
|
] + if_tensorrt([":tensorrt_lib"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_proto_library(
|
tf_proto_library(
|
||||||
|
@ -200,18 +200,6 @@ int64 TFAttrs::get<int64>(const string& key) const {
|
|||||||
return this->at(key)->i();
|
return this->at(key)->i();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorShapeType>
|
|
||||||
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
|
||||||
bool ignore_first_dim) {
|
|
||||||
nvinfer1::Dims trt_dims;
|
|
||||||
const int offset = (ignore_first_dim ? 1 : 0);
|
|
||||||
for (int i = offset; i < shape.dims(); i++) {
|
|
||||||
trt_dims.d[i - offset] = shape.dim_size(i);
|
|
||||||
}
|
|
||||||
trt_dims.nbDims = shape.dims() - offset;
|
|
||||||
return trt_dims;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Container>
|
template <typename Container>
|
||||||
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
|
Status TensorShapeArrayToTrtDims(const Container& shape, nvinfer1::Dims* out,
|
||||||
bool ignore_first_dim = false) {
|
bool ignore_first_dim = false) {
|
||||||
@ -314,66 +302,6 @@ Status ValidateTensorProperties(const string& producer_node_type,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugString(const nvinfer1::DimensionType type) {
|
|
||||||
switch (type) {
|
|
||||||
case nvinfer1::DimensionType::kSPATIAL:
|
|
||||||
return "kSPATIAL";
|
|
||||||
case nvinfer1::DimensionType::kCHANNEL:
|
|
||||||
return "kCHANNEL";
|
|
||||||
case nvinfer1::DimensionType::kINDEX:
|
|
||||||
return "kINDEX";
|
|
||||||
case nvinfer1::DimensionType::kSEQUENCE:
|
|
||||||
return "kSEQUENCE";
|
|
||||||
default:
|
|
||||||
return StrCat(static_cast<int>(type), "=unknown");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString(const nvinfer1::DataType trt_dtype) {
|
|
||||||
switch (trt_dtype) {
|
|
||||||
case nvinfer1::DataType::kFLOAT:
|
|
||||||
return "kFLOAT";
|
|
||||||
case nvinfer1::DataType::kHALF:
|
|
||||||
return "kHALF";
|
|
||||||
case nvinfer1::DataType::kINT8:
|
|
||||||
return "kINT8";
|
|
||||||
case nvinfer1::DataType::kINT32:
|
|
||||||
return "kINT32";
|
|
||||||
default:
|
|
||||||
return "Invalid TRT data type";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString(const nvinfer1::Dims& dims) {
|
|
||||||
string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
|
|
||||||
for (int i = 0; i < dims.nbDims; ++i) {
|
|
||||||
StrAppend(&out, dims.d[i]);
|
|
||||||
if (VLOG_IS_ON(2)) {
|
|
||||||
StrAppend(&out, "[", DebugString(dims.type[i]), "],");
|
|
||||||
} else {
|
|
||||||
StrAppend(&out, ",");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
StrAppend(&out, ")");
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString(const nvinfer1::Permutation& permutation, int len) {
|
|
||||||
string out = "nvinfer1::Permutation(";
|
|
||||||
for (int i = 0; i < len; ++i) {
|
|
||||||
StrAppend(&out, permutation.order[i], ",");
|
|
||||||
}
|
|
||||||
StrAppend(&out, ")");
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString(const nvinfer1::ITensor& tensor) {
|
|
||||||
return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
|
|
||||||
", name=", tensor.getName(),
|
|
||||||
", dtype=", DebugString(tensor.getType()),
|
|
||||||
", dims=", DebugString(tensor.getDimensions()), ")");
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
|
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
|
||||||
const TRT_TensorOrWeights& operand_r,
|
const TRT_TensorOrWeights& operand_r,
|
||||||
const bool check_feasibility,
|
const bool check_feasibility,
|
||||||
@ -581,14 +509,6 @@ inline nvinfer1::Dims GetTrtDimsForTensor(const Tensor& tensor) {
|
|||||||
return dims;
|
return dims;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
|
|
||||||
if (dims.nbDims < 0) return false;
|
|
||||||
for (int d = 0; d < dims.nbDims; ++d) {
|
|
||||||
if (dims.d[d] < 0) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t Prod(const nvinfer1::Dims& dims) {
|
int64_t Prod(const nvinfer1::Dims& dims) {
|
||||||
int64_t count = 1;
|
int64_t count = 1;
|
||||||
for (int d = 0; d < dims.nbDims; ++d) {
|
for (int d = 0; d < dims.nbDims; ++d) {
|
||||||
@ -732,9 +652,10 @@ size_t TRT_ShapedWeights::size_bytes() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
string TRT_ShapedWeights::DebugString() const {
|
string TRT_ShapedWeights::DebugString() const {
|
||||||
return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
|
return StrCat(
|
||||||
", type=", convert::DebugString(type_),
|
"TRT_ShapedWeights(shape=", tensorflow::tensorrt::DebugString(shape_),
|
||||||
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
|
", type=", tensorflow::tensorrt::DebugString(type_),
|
||||||
|
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
|
||||||
}
|
}
|
||||||
|
|
||||||
// A fake ITensor implementation used to check whether the TF-TRT converter can
|
// A fake ITensor implementation used to check whether the TF-TRT converter can
|
||||||
@ -858,7 +779,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
|
|||||||
string TRT_TensorOrWeights::DebugString() const {
|
string TRT_TensorOrWeights::DebugString() const {
|
||||||
string output = "TRT_TensorOrWeights(type=";
|
string output = "TRT_TensorOrWeights(type=";
|
||||||
if (is_tensor()) {
|
if (is_tensor()) {
|
||||||
StrAppend(&output, "tensor=", convert::DebugString(*tensor()),
|
StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()),
|
||||||
", batch_size=", batch_size_);
|
", batch_size=", batch_size_);
|
||||||
} else {
|
} else {
|
||||||
StrAppend(&output, "weights=", weights_.DebugString());
|
StrAppend(&output, "weights=", weights_.DebugString());
|
||||||
@ -2234,23 +2155,22 @@ Status ConvertConv2DHelper(OpConverterParams* params, int group,
|
|||||||
// argument output_shape and thus the TRT output shape could be wrong
|
// argument output_shape and thus the TRT output shape could be wrong
|
||||||
// in case of strides>1.
|
// in case of strides>1.
|
||||||
if (is_conv2d_backprop_input) {
|
if (is_conv2d_backprop_input) {
|
||||||
auto tf_output_shape = backprop_output_size.GetTrtDims();
|
auto tf_output_shape =
|
||||||
|
static_cast<int*>(backprop_output_size.weights().GetValues());
|
||||||
nvinfer1::Dims trt_output_shape = output_tensor->getDimensions();
|
nvinfer1::Dims trt_output_shape = output_tensor->getDimensions();
|
||||||
// What determines the padding size is the difference between the given
|
// What determines the padding size is the difference between the given
|
||||||
// input_sizes (tf_output_shape) and TRT computed size.
|
// input_sizes (tf_output_shape) and TRT computed size.
|
||||||
const int height_diff =
|
const int height_diff = tf_output_shape[h_index] - trt_output_shape.d[1];
|
||||||
tf_output_shape.d[h_index - 1] - trt_output_shape.d[1];
|
const int width_diff = tf_output_shape[w_index] - trt_output_shape.d[2];
|
||||||
const int width_diff =
|
|
||||||
tf_output_shape.d[w_index - 1] - trt_output_shape.d[2];
|
|
||||||
if ((height_diff < 0) || (width_diff < 0)) {
|
if ((height_diff < 0) || (width_diff < 0)) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"input_sizes argument of Conv2DBackprop (i.e. output_shape argument "
|
"input_sizes argument of Conv2DBackprop (i.e. output_shape argument "
|
||||||
"of conv2d_transpose)",
|
"of conv2d_transpose) ",
|
||||||
"is too small for the given out_backprop argument of Conv2DBackprop "
|
"is too small for the given out_backprop argument of Conv2DBackprop "
|
||||||
"(i.e. input argument of conv2d_transpose).",
|
"(i.e. input argument of conv2d_transpose). Expect: ",
|
||||||
"(", tf_output_shape.d[h_index - 1], ", ",
|
"(", tf_output_shape[h_index], ", ", tf_output_shape[w_index],
|
||||||
tf_output_shape.d[w_index - 1], ") >= ", "(", trt_output_shape.d[1],
|
") >= ", "(", trt_output_shape.d[1], ", ", trt_output_shape.d[2],
|
||||||
", ", trt_output_shape.d[2], ")", node_def.name());
|
") for op ", node_def.name());
|
||||||
}
|
}
|
||||||
// Only add a padding layer if padding sizes are larger than 0
|
// Only add a padding layer if padding sizes are larger than 0
|
||||||
if ((height_diff > 0) || (width_diff > 0)) {
|
if ((height_diff > 0) || (width_diff > 0)) {
|
||||||
|
@ -42,14 +42,6 @@ namespace tensorrt {
|
|||||||
namespace convert {
|
namespace convert {
|
||||||
using ::stream_executor::port::StatusOr;
|
using ::stream_executor::port::StatusOr;
|
||||||
|
|
||||||
#define IS_TRT_VERSION_GE(major, minor, patch, build) \
|
|
||||||
((NV_TENSORRT_MAJOR > major) || \
|
|
||||||
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
|
|
||||||
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
|
||||||
NV_TENSORRT_PATCH > patch) || \
|
|
||||||
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
|
||||||
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
|
|
||||||
|
|
||||||
struct EngineConnection {
|
struct EngineConnection {
|
||||||
// Constructs a non-control edge.
|
// Constructs a non-control edge.
|
||||||
EngineConnection(const string& outside, int out_id, int out_port,
|
EngineConnection(const string& outside, int out_id, int out_port,
|
||||||
@ -164,11 +156,6 @@ class OutputEdgeValidator {
|
|||||||
bool operator()(const Edge* out_edge) const;
|
bool operator()(const Edge* out_edge) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
string DebugString(const nvinfer1::DimensionType type);
|
|
||||||
string DebugString(const nvinfer1::DataType trt_dtype);
|
|
||||||
string DebugString(const nvinfer1::Dims& dims);
|
|
||||||
string DebugString(const nvinfer1::Permutation& permutation, int len);
|
|
||||||
string DebugString(const nvinfer1::ITensor& tensor);
|
|
||||||
int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
|
int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
|
||||||
int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
|
int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
|
||||||
|
|
||||||
|
@ -3856,7 +3856,7 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Ok.
|
// Ok.
|
||||||
const int kConv2DOKCases = 7;
|
const int kConv2DOKCases = 9;
|
||||||
TestParams ok_params[kConv2DOKCases] = {
|
TestParams ok_params[kConv2DOKCases] = {
|
||||||
// Basic
|
// Basic
|
||||||
TestParams{/*input_dims=*/{1, 2, 3},
|
TestParams{/*input_dims=*/{1, 2, 3},
|
||||||
@ -3978,8 +3978,10 @@ TEST_F(OpConverterTest, ConvertConv2D) {
|
|||||||
AddTestWeights<float>("weights", ok_params[i].filter_dims,
|
AddTestWeights<float>("weights", ok_params[i].filter_dims,
|
||||||
ok_params[i].filter);
|
ok_params[i].filter);
|
||||||
if (ok_params[i].is_conv2d_backprop_input) {
|
if (ok_params[i].is_conv2d_backprop_input) {
|
||||||
AddTestWeights<float>("input_sizes", ok_params[i].expected_output_dims,
|
std::vector<int> tf_input_sizes = ok_params[i].expected_output_dims;
|
||||||
ok_params[i].expected_output);
|
tf_input_sizes.insert(tf_input_sizes.begin(), 1); // Add batch dimension.
|
||||||
|
QCHECK_EQ(4, tf_input_sizes.size());
|
||||||
|
AddTestWeights<int>("input_sizes", {4}, tf_input_sizes);
|
||||||
}
|
}
|
||||||
RunValidationAndConversion(node_def);
|
RunValidationAndConversion(node_def);
|
||||||
TRT_TensorOrWeights output;
|
TRT_TensorOrWeights output;
|
||||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
@ -51,5 +53,71 @@ Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
using absl::StrAppend;
|
||||||
|
using absl::StrCat;
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::DimensionType type) {
|
||||||
|
switch (type) {
|
||||||
|
case nvinfer1::DimensionType::kSPATIAL:
|
||||||
|
return "kSPATIAL";
|
||||||
|
case nvinfer1::DimensionType::kCHANNEL:
|
||||||
|
return "kCHANNEL";
|
||||||
|
case nvinfer1::DimensionType::kINDEX:
|
||||||
|
return "kINDEX";
|
||||||
|
case nvinfer1::DimensionType::kSEQUENCE:
|
||||||
|
return "kSEQUENCE";
|
||||||
|
default:
|
||||||
|
return StrCat(static_cast<int>(type), "=unknown");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::Dims& dims) {
|
||||||
|
string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
|
||||||
|
for (int i = 0; i < dims.nbDims; ++i) {
|
||||||
|
StrAppend(&out, dims.d[i]);
|
||||||
|
if (VLOG_IS_ON(2)) {
|
||||||
|
StrAppend(&out, "[", DebugString(dims.type[i]), "],");
|
||||||
|
} else {
|
||||||
|
StrAppend(&out, ",");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
StrAppend(&out, ")");
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::DataType trt_dtype) {
|
||||||
|
switch (trt_dtype) {
|
||||||
|
case nvinfer1::DataType::kFLOAT:
|
||||||
|
return "kFLOAT";
|
||||||
|
case nvinfer1::DataType::kHALF:
|
||||||
|
return "kHALF";
|
||||||
|
case nvinfer1::DataType::kINT8:
|
||||||
|
return "kINT8";
|
||||||
|
case nvinfer1::DataType::kINT32:
|
||||||
|
return "kINT32";
|
||||||
|
default:
|
||||||
|
return "Invalid TRT data type";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::Permutation& permutation, int len) {
|
||||||
|
string out = "nvinfer1::Permutation(";
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
StrAppend(&out, permutation.order[i], ",");
|
||||||
|
}
|
||||||
|
StrAppend(&out, ")");
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::ITensor& tensor) {
|
||||||
|
return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
|
||||||
|
", name=", tensor.getName(),
|
||||||
|
", dtype=", DebugString(tensor.getType()),
|
||||||
|
", dims=", DebugString(tensor.getDimensions()), ")");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace tensorrt
|
} // namespace tensorrt
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -17,9 +17,15 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
|
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
#include "third_party/tensorrt/NvInfer.h"
|
||||||
|
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tensorrt {
|
namespace tensorrt {
|
||||||
|
|
||||||
@ -45,6 +51,51 @@ Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name);
|
|||||||
|
|
||||||
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
|
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
|
||||||
|
|
||||||
|
// Define a hash function for vector<TensorShape> because it is used as the key
|
||||||
|
// for the engine cache.
|
||||||
|
struct VectorTensorShapeHasher {
|
||||||
|
std::size_t operator()(const std::vector<TensorShape>& key) const {
|
||||||
|
return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
|
||||||
|
#define IS_TRT_VERSION_GE(major, minor, patch, build) \
|
||||||
|
((NV_TENSORRT_MAJOR > major) || \
|
||||||
|
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
|
||||||
|
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
||||||
|
NV_TENSORRT_PATCH > patch) || \
|
||||||
|
(NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
|
||||||
|
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
|
||||||
|
|
||||||
|
string DebugString(const nvinfer1::DimensionType type);
|
||||||
|
string DebugString(const nvinfer1::Dims& dims);
|
||||||
|
string DebugString(const nvinfer1::DataType trt_dtype);
|
||||||
|
string DebugString(const nvinfer1::Permutation& permutation, int len);
|
||||||
|
string DebugString(const nvinfer1::ITensor& tensor);
|
||||||
|
|
||||||
|
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
|
||||||
|
if (dims.nbDims < 0) return false;
|
||||||
|
for (int d = 0; d < dims.nbDims; ++d) {
|
||||||
|
if (dims.d[d] < 0) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename TensorShapeType>
|
||||||
|
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
|
||||||
|
bool ignore_first_dim) {
|
||||||
|
nvinfer1::Dims trt_dims;
|
||||||
|
const int offset = (ignore_first_dim ? 1 : 0);
|
||||||
|
for (int i = offset; i < shape.dims(); i++) {
|
||||||
|
trt_dims.d[i - offset] = shape.dim_size(i);
|
||||||
|
}
|
||||||
|
trt_dims.nbDims = shape.dims() - offset;
|
||||||
|
return trt_dims;
|
||||||
|
}
|
||||||
|
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||||
|
|
||||||
} // namespace tensorrt
|
} // namespace tensorrt
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -114,14 +114,6 @@ class LRUCache {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Define a hash function for vector<TensorShape> because it is used as the key
|
|
||||||
// for the engine cache.
|
|
||||||
struct VectorTensorShapeHasher {
|
|
||||||
std::size_t operator()(const std::vector<TensorShape>& key) const {
|
|
||||||
return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#if GOOGLE_TENSORRT
|
#if GOOGLE_TENSORRT
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user