Add a debug string to AbstractTensorHandle
There's enough information to make it useful, and it's good to have a standard method. Before this TFRT TensorHandles and custom device handles didn't have debug strings. Both may want to override the method eventually. PiperOrigin-RevId: 346876438 Change-Id: I1e679ec891446ad0cb60c03c254fe16d3f3d49ef
This commit is contained in:
parent
0796a95989
commit
e5ffdb506b
tensorflow
c/eager
core/common_runtime/eager
@ -507,6 +507,7 @@ tf_cuda_cc_test(
|
||||
|
||||
cc_library(
|
||||
name = "abstract_tensor_handle",
|
||||
srcs = ["abstract_tensor_handle.cc"],
|
||||
hdrs = ["abstract_tensor_handle.h"],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
|
33
tensorflow/c/eager/abstract_tensor_handle.cc
Normal file
33
tensorflow/c/eager/abstract_tensor_handle.cc
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
std::string AbstractTensorHandle::DebugString() const {
|
||||
PartialTensorShape shape;
|
||||
Status s = Shape(&shape);
|
||||
std::string shape_string;
|
||||
if (!s.ok()) {
|
||||
shape_string = "<error computing shape>";
|
||||
} else {
|
||||
shape_string = shape.DebugString();
|
||||
}
|
||||
return absl::StrCat("TensorHandle(shape=", shape_string,
|
||||
", dtype=", DataType_Name(DataType()), ")");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -38,6 +38,10 @@ class AbstractTensorHandle : public core::RefCounted {
|
||||
virtual tensorflow::Status Shape(
|
||||
tensorflow::PartialTensorShape* shape) const = 0;
|
||||
|
||||
// The default debug string includes a shape and dtype. Implementations are
|
||||
// free to override it with something more informative.
|
||||
virtual std::string DebugString() const;
|
||||
|
||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||
|
||||
private:
|
||||
|
@ -93,6 +93,7 @@ TEST(CustomDevice, TestTensorHandle) {
|
||||
s = tensor->NumElements(&num_elements);
|
||||
ASSERT_TRUE(s.ok()) << s.error_message();
|
||||
EXPECT_EQ(3, num_elements);
|
||||
EXPECT_EQ("TensorHandle(shape=[3], dtype=DT_FLOAT)", tensor->DebugString());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user