Move all DebugString from convert_nodes to utils.

This commit is contained in:
Tamas Bela Feher 2019-12-18 17:46:20 +01:00
parent ed84b09ce1
commit fef1b9b81f
3 changed files with 7 additions and 40 deletions

View File

@ -302,37 +302,6 @@ Status ValidateTensorProperties(const string& producer_node_type,
return Status::OK();
}
string DebugString(const nvinfer1::DataType trt_dtype) {
switch (trt_dtype) {
case nvinfer1::DataType::kFLOAT:
return "kFLOAT";
case nvinfer1::DataType::kHALF:
return "kHALF";
case nvinfer1::DataType::kINT8:
return "kINT8";
case nvinfer1::DataType::kINT32:
return "kINT32";
default:
return "Invalid TRT data type";
}
}
string DebugString(const nvinfer1::Permutation& permutation, int len) {
string out = "nvinfer1::Permutation(";
for (int i = 0; i < len; ++i) {
StrAppend(&out, permutation.order[i], ",");
}
StrAppend(&out, ")");
return out;
}
string DebugString(const nvinfer1::ITensor& tensor) {
return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
", name=", tensor.getName(),
", dtype=", DebugString(tensor.getType()),
", dims=", DebugString(tensor.getDimensions()), ")");
}
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r,
const bool check_feasibility,
@ -683,8 +652,9 @@ size_t TRT_ShapedWeights::size_bytes() const {
}
string TRT_ShapedWeights::DebugString() const {
return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
", type=", convert::DebugString(type_),
return StrCat("TRT_ShapedWeights(shape=",
tensorflow::tensorrt::DebugString(shape_),
", type=", tensorflow::tensorrt::DebugString(type_),
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
}
@ -809,7 +779,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
string TRT_TensorOrWeights::DebugString() const {
string output = "TRT_TensorOrWeights(type=";
if (is_tensor()) {
StrAppend(&output, "tensor=", convert::DebugString(*tensor()),
StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()),
", batch_size=", batch_size_);
} else {
StrAppend(&output, "weights=", weights_.DebugString());

View File

@ -156,11 +156,6 @@ class OutputEdgeValidator {
bool operator()(const Edge* out_edge) const;
};
string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);
int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);

View File

@ -70,8 +70,10 @@ struct VectorTensorShapeHasher {
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
if (dims.nbDims < 0) return false;