[tf.data] Record element sizes in IteratorGetNext xprof traces.
PiperOrigin-RevId: 348100471 Change-Id: Ic534eb373936f9d610e36d3ce1028da549450de1
This commit is contained in:
parent
d0ed3b210c
commit
c5c8ad8892
@ -946,6 +946,17 @@ AsyncOpKernel* IteratorGetNextOp::AsAsync() {
|
||||
return type_string() == "IteratorGetNextSync" ? nullptr : this;
|
||||
}
|
||||
|
||||
void RecordElementSize(const std::vector<Tensor> element,
|
||||
profiler::TraceMe* traceme) {
|
||||
traceme->AppendMetadata([&]() {
|
||||
int64 element_size = 0;
|
||||
for (const auto& component : element) {
|
||||
element_size += component.TotalBytes();
|
||||
}
|
||||
return profiler::TraceMeEncode({{"element_size", element_size}});
|
||||
});
|
||||
}
|
||||
|
||||
Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
|
||||
profiler::TraceMe traceme(
|
||||
[&] {
|
||||
@ -968,6 +979,7 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
|
||||
}
|
||||
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
|
||||
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
|
||||
RecordElementSize(components, &traceme);
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
@ -995,6 +1007,7 @@ Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) {
|
||||
if (end_of_sequence) {
|
||||
return WriteOptionalNoneToOutput(ctx, 0);
|
||||
} else {
|
||||
RecordElementSize(components, &traceme);
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
if (components[i].dtype() != output_types_[i]) {
|
||||
return errors::InvalidArgument(
|
||||
|
@ -299,14 +299,22 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
data::TraceMeMetadata GetTraceMeMetadata() const override {
|
||||
int64 limit = -1, size = -1;
|
||||
data::TraceMeMetadata result;
|
||||
// NOTE: We only set the parallelism value if the lock can be acquired
|
||||
// right away to avoid introducing tracing overhead.
|
||||
if (mu_->try_lock()) {
|
||||
limit = buffer_limit();
|
||||
size = buffer_.size();
|
||||
if (!buffer_.empty()) {
|
||||
std::vector<std::string> shapes(buffer_.front().value.size());
|
||||
for (const auto& component : buffer_.front().value) {
|
||||
shapes.push_back(component.shape().DebugString());
|
||||
}
|
||||
result.push_back(std::make_pair("next_element_shapes",
|
||||
absl::StrJoin(shapes, ",")));
|
||||
}
|
||||
mu_->unlock();
|
||||
}
|
||||
data::TraceMeMetadata result;
|
||||
result.push_back(std::make_pair(
|
||||
"buffer_limit",
|
||||
strings::Printf("%lld", static_cast<long long>(limit))));
|
||||
|
Loading…
Reference in New Issue
Block a user