[tf.data] Record element sizes in IteratorGetNext xprof traces.

PiperOrigin-RevId: 348100471
Change-Id: Ic534eb373936f9d610e36d3ce1028da549450de1
This commit is contained in:
Andrew Audibert 2020-12-17 15:02:13 -08:00 committed by TensorFlower Gardener
parent d0ed3b210c
commit c5c8ad8892
2 changed files with 22 additions and 1 deletions

View File

@ -946,6 +946,17 @@ AsyncOpKernel* IteratorGetNextOp::AsAsync() {
return type_string() == "IteratorGetNextSync" ? nullptr : this;
}
void RecordElementSize(const std::vector<Tensor> element,
profiler::TraceMe* traceme) {
traceme->AppendMetadata([&]() {
int64 element_size = 0;
for (const auto& component : element) {
element_size += component.TotalBytes();
}
return profiler::TraceMeEncode({{"element_size", element_size}});
});
}
Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
profiler::TraceMe traceme(
[&] {
@ -968,6 +979,7 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
}
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
RecordElementSize(components, &traceme);
for (int i = 0; i < components.size(); ++i) {
ctx->set_output(i, components[i]);
}
@ -995,6 +1007,7 @@ Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
if (end_of_sequence) {
return WriteOptionalNoneToOutput(ctx, 0);
} else {
RecordElementSize(components, &traceme);
for (int i = 0; i < components.size(); ++i) {
if (components[i].dtype() != output_types_[i]) {
return errors::InvalidArgument(

View File

@ -299,14 +299,22 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
data::TraceMeMetadata GetTraceMeMetadata() const override {
int64 limit = -1, size = -1;
data::TraceMeMetadata result;
// NOTE: We only set the parallelism value if the lock can be acquired
// right away to avoid introducing tracing overhead.
if (mu_->try_lock()) {
limit = buffer_limit();
size = buffer_.size();
if (!buffer_.empty()) {
std::vector<std::string> shapes(buffer_.front().value.size());
for (const auto& component : buffer_.front().value) {
shapes.push_back(component.shape().DebugString());
}
result.push_back(std::make_pair("next_element_shapes",
absl::StrJoin(shapes, ",")));
}
mu_->unlock();
}
data::TraceMeMetadata result;
result.push_back(std::make_pair(
"buffer_limit",
strings::Printf("%lld", static_cast<long long>(limit))));