diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index ca21c193d63..ea5ed526ddd 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tensorrt { @@ -51,5 +53,72 @@ Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) { return Status::OK(); } +using absl::StrAppend; +using absl::StrCat; + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +string DebugString(const nvinfer1::DimensionType type) { + switch (type) { + case nvinfer1::DimensionType::kSPATIAL: + return "kSPATIAL"; + case nvinfer1::DimensionType::kCHANNEL: + return "kCHANNEL"; + case nvinfer1::DimensionType::kINDEX: + return "kINDEX"; + case nvinfer1::DimensionType::kSEQUENCE: + return "kSEQUENCE"; + default: + return StrCat(static_cast(type), "=unknown"); + } +} + +string DebugString(const nvinfer1::Dims& dims) { + string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d="); + for (int i = 0; i < dims.nbDims; ++i) { + StrAppend(&out, dims.d[i]); + if (VLOG_IS_ON(2)) { + StrAppend(&out, "[", DebugString(dims.type[i]), "],"); + } else { + StrAppend(&out, ","); + } + } + StrAppend(&out, ")"); + return out; +} + +string DebugString(const nvinfer1::DataType trt_dtype) { + switch (trt_dtype) { + case nvinfer1::DataType::kFLOAT: + return "kFLOAT"; + case nvinfer1::DataType::kHALF: + return "kHALF"; + case nvinfer1::DataType::kINT8: + return "kINT8"; + case nvinfer1::DataType::kINT32: + return "kINT32"; + default: + return "Invalid TRT data type"; + } +} + +string DebugString(const nvinfer1::Permutation& permutation, int len) { + string out = "nvinfer1::Permutation("; + for (int i = 0; i < len; ++i) { + StrAppend(&out, permutation.order[i], ","); + } + StrAppend(&out, ")"); + return out; +} + +string DebugString(const nvinfer1::ITensor& tensor) { + return StrCat("nvinfer1::ITensor(@", reinterpret_cast(&tensor), + ", name=", tensor.getName(), + ", dtype=", DebugString(tensor.getType()), + ", dims=", + tensorflow::tensorrt::DebugString(tensor.getDimensions()), ")"); +} + +#endif + } // namespace tensorrt } // namespace tensorflow