Move all DebugString from convert_nodes to utils.
This commit is contained in:
parent
ed84b09ce1
commit
fef1b9b81f
@ -302,37 +302,6 @@ Status ValidateTensorProperties(const string& producer_node_type,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string DebugString(const nvinfer1::DataType trt_dtype) {
|
||||
switch (trt_dtype) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return "kFLOAT";
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return "kHALF";
|
||||
case nvinfer1::DataType::kINT8:
|
||||
return "kINT8";
|
||||
case nvinfer1::DataType::kINT32:
|
||||
return "kINT32";
|
||||
default:
|
||||
return "Invalid TRT data type";
|
||||
}
|
||||
}
|
||||
|
||||
string DebugString(const nvinfer1::Permutation& permutation, int len) {
|
||||
string out = "nvinfer1::Permutation(";
|
||||
for (int i = 0; i < len; ++i) {
|
||||
StrAppend(&out, permutation.order[i], ",");
|
||||
}
|
||||
StrAppend(&out, ")");
|
||||
return out;
|
||||
}
|
||||
|
||||
string DebugString(const nvinfer1::ITensor& tensor) {
|
||||
return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
|
||||
", name=", tensor.getName(),
|
||||
", dtype=", DebugString(tensor.getType()),
|
||||
", dims=", DebugString(tensor.getDimensions()), ")");
|
||||
}
|
||||
|
||||
Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
|
||||
const TRT_TensorOrWeights& operand_r,
|
||||
const bool check_feasibility,
|
||||
@ -683,8 +652,9 @@ size_t TRT_ShapedWeights::size_bytes() const {
|
||||
}
|
||||
|
||||
string TRT_ShapedWeights::DebugString() const {
|
||||
return StrCat("TRT_ShapedWeights(shape=", convert::DebugString(shape_),
|
||||
", type=", convert::DebugString(type_),
|
||||
return StrCat("TRT_ShapedWeights(shape=",
|
||||
tensorflow::tensorrt::DebugString(shape_),
|
||||
", type=", tensorflow::tensorrt::DebugString(type_),
|
||||
", values=", reinterpret_cast<uintptr_t>(GetValues()), ")");
|
||||
}
|
||||
|
||||
@ -809,7 +779,7 @@ nvinfer1::Dims TRT_TensorOrWeights::GetTrtDims() const {
|
||||
string TRT_TensorOrWeights::DebugString() const {
|
||||
string output = "TRT_TensorOrWeights(type=";
|
||||
if (is_tensor()) {
|
||||
StrAppend(&output, "tensor=", convert::DebugString(*tensor()),
|
||||
StrAppend(&output, "tensor=", tensorflow::tensorrt::DebugString(*tensor()),
|
||||
", batch_size=", batch_size_);
|
||||
} else {
|
||||
StrAppend(&output, "weights=", weights_.DebugString());
|
||||
|
@ -156,11 +156,6 @@ class OutputEdgeValidator {
|
||||
bool operator()(const Edge* out_edge) const;
|
||||
};
|
||||
|
||||
string DebugString(const nvinfer1::DimensionType type);
|
||||
string DebugString(const nvinfer1::DataType trt_dtype);
|
||||
string DebugString(const nvinfer1::Dims& dims);
|
||||
string DebugString(const nvinfer1::Permutation& permutation, int len);
|
||||
string DebugString(const nvinfer1::ITensor& tensor);
|
||||
int64_t TrtWeightDimsNumElements(const nvinfer1::Dims& dims);
|
||||
int64_t TrtTensorDimsNumElements(const nvinfer1::Dims& dims);
|
||||
|
||||
|
@ -70,8 +70,10 @@ struct VectorTensorShapeHasher {
|
||||
NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
|
||||
|
||||
string DebugString(const nvinfer1::DimensionType type);
|
||||
|
||||
string DebugString(const nvinfer1::Dims& dims);
|
||||
string DebugString(const nvinfer1::DataType trt_dtype);
|
||||
string DebugString(const nvinfer1::Permutation& permutation, int len);
|
||||
string DebugString(const nvinfer1::ITensor& tensor);
|
||||
|
||||
inline bool HasStaticShape(const nvinfer1::Dims& dims) {
|
||||
if (dims.nbDims < 0) return false;
|
||||
|
Loading…
x
Reference in New Issue
Block a user