Add definition of DebugString to utils.cc
This commit is contained in:
parent
549999490d
commit
ed84b09ce1
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tensorrt {
|
||||
@ -51,5 +53,72 @@ Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
#if GOOGLE_CUDA && GOOGLE_TENSORRT
|
||||
string DebugString(const nvinfer1::DimensionType type) {
|
||||
switch (type) {
|
||||
case nvinfer1::DimensionType::kSPATIAL:
|
||||
return "kSPATIAL";
|
||||
case nvinfer1::DimensionType::kCHANNEL:
|
||||
return "kCHANNEL";
|
||||
case nvinfer1::DimensionType::kINDEX:
|
||||
return "kINDEX";
|
||||
case nvinfer1::DimensionType::kSEQUENCE:
|
||||
return "kSEQUENCE";
|
||||
default:
|
||||
return StrCat(static_cast<int>(type), "=unknown");
|
||||
}
|
||||
}
|
||||
|
||||
string DebugString(const nvinfer1::Dims& dims) {
|
||||
string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
|
||||
for (int i = 0; i < dims.nbDims; ++i) {
|
||||
StrAppend(&out, dims.d[i]);
|
||||
if (VLOG_IS_ON(2)) {
|
||||
StrAppend(&out, "[", DebugString(dims.type[i]), "],");
|
||||
} else {
|
||||
StrAppend(&out, ",");
|
||||
}
|
||||
}
|
||||
StrAppend(&out, ")");
|
||||
return out;
|
||||
}
|
||||
|
||||
string DebugString(const nvinfer1::DataType trt_dtype) {
|
||||
switch (trt_dtype) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return "kFLOAT";
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return "kHALF";
|
||||
case nvinfer1::DataType::kINT8:
|
||||
return "kINT8";
|
||||
case nvinfer1::DataType::kINT32:
|
||||
return "kINT32";
|
||||
default:
|
||||
return "Invalid TRT data type";
|
||||
}
|
||||
}
|
||||
|
||||
string DebugString(const nvinfer1::Permutation& permutation, int len) {
|
||||
string out = "nvinfer1::Permutation(";
|
||||
for (int i = 0; i < len; ++i) {
|
||||
StrAppend(&out, permutation.order[i], ",");
|
||||
}
|
||||
StrAppend(&out, ")");
|
||||
return out;
|
||||
}
|
||||
|
||||
string DebugString(const nvinfer1::ITensor& tensor) {
|
||||
return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
|
||||
", name=", tensor.getName(),
|
||||
", dtype=", DebugString(tensor.getType()),
|
||||
", dims=",
|
||||
tensorflow::tensorrt::DebugString(tensor.getDimensions()), ")");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace tensorrt
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user