diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 94c07e89778..79c1804830e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -302,37 +302,6 @@ Status ValidateTensorProperties(const string& producer_node_type, return Status::OK(); } -string DebugString(const nvinfer1::DataType trt_dtype) { - switch (trt_dtype) { - case nvinfer1::DataType::kFLOAT: - return "kFLOAT"; - case nvinfer1::DataType::kHALF: - return "kHALF"; - case nvinfer1::DataType::kINT8: - return "kINT8"; - case nvinfer1::DataType::kINT32: - return "kINT32"; - default: - return "Invalid TRT data type"; - } -} - -string DebugString(const nvinfer1::Permutation& permutation, int len) { - string out = "nvinfer1::Permutation("; - for (int i = 0; i < len; ++i) { - StrAppend(&out, permutation.order[i], ","); - } - StrAppend(&out, ")"); - return out; -} - -string DebugString(const nvinfer1::ITensor& tensor) { - return StrCat("nvinfer1::ITensor(@", reinterpret_cast(&tensor), - ", name=", tensor.getName(), - ", dtype=", DebugString(tensor.getType()), - ", dims=", DebugString(tensor.getDimensions()), ")"); -} - Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r, const bool check_feasibility, @@ -683,8 +652,9 @@ size_t TRT_ShapedWeights::size_bytes() const { } string TRT_ShapedWeights::DebugString() const { - return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_), - ", type=", convert::DebugString(type_), + return StrCat("TRT_ShapedWeights(shape=", + tensorflow::tensorrt::DebugString(shape_), + ", type=", tensorflow::tensorrt::DebugString(type_), ", values=", reinterpret_cast(GetValues()), ")"); } @@ -809,7 +779,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const { string TRT_TensorOrWeights::DebugString() const { string output = "TRT_TensorOrWeights(type="; if (is_tensor()) { - StrAppend(&output, "tensor=", convert::DebugString(*tensor()), + StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()), ", batch_size=", batch_size_); } else { StrAppend(&output, "weights=", weights_.DebugString()); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 91d2939353f..a9f579c9ed7 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -156,11 +156,6 @@ class OutputEdgeValidator { bool operator()(const Edge* out_edge) const; }; -string DebugString(const nvinfer1::DimensionType type); -string DebugString(const nvinfer1::DataType trt_dtype); -string DebugString(const nvinfer1::Dims& dims); -string DebugString(const nvinfer1::Permutation& permutation, int len); -string DebugString(const nvinfer1::ITensor& tensor); int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims); int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims); diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 1b088025686..af7c2623ed2 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -70,8 +70,10 @@ struct VectorTensorShapeHasher { NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build)) string DebugString(const nvinfer1::DimensionType type); - string DebugString(const nvinfer1::Dims& dims); +string DebugString(const nvinfer1::DataType trt_dtype); +string DebugString(const nvinfer1::Permutation& permutation, int len); +string DebugString(const nvinfer1::ITensor& tensor); inline bool HasStaticShape(const nvinfer1::Dims& dims) { if (dims.nbDims < 0) return false;