[tf.data] Various changes to iterator metrics collection.
This CL: - adds support for collecting aggregate time tf.data iterator spent actively servicing `GetNext` requests (accounting for concurrent requests) - adds support for collecting the "lifetime" of tf.data iterator, that is time between receiving the first `GetNext` request and servicing the last `GetNext` request - removes support for collecting time between subsequent calls to IteratorGetNextOp PiperOrigin-RevId: 340474712 Change-Id: Icdfd35c46623160e9faacf1af69f897af88049f6
This commit is contained in:
parent
2e8c15ee1c
commit
b0140088d4
@ -94,25 +94,23 @@ auto* tf_data_experiment_counter = monitoring::Counter<1>::New(
|
||||
auto* tf_data_fingerprint_counter = monitoring::Counter<1>::New(
|
||||
"/tensorflow/data/fingerprint", "tf.data fingerprint", "name");
|
||||
|
||||
auto* tf_data_getnext_duration_usecs_histogram = monitoring::Sampler<0>::New(
|
||||
auto* tf_data_get_next_duration_usecs_histogram = monitoring::Sampler<0>::New(
|
||||
{"/tensorflow/data/getnext_duration",
|
||||
"Microseconds spent fetching an element from tf.data Dataset iterator."},
|
||||
"Microseconds spent fetching an element from tf.data iterator."},
|
||||
// Power of 2 with bucket count 10 (1024 microseconds) and 1 second.
|
||||
{monitoring::Buckets::Explicit(
|
||||
{2., 4., 8., 16., 32., 64., 128., 256., 512., 1024., 1e6})});
|
||||
|
||||
auto* tf_data_getnext_time_between_msecs_histogram =
|
||||
monitoring::Sampler<0>::New(
|
||||
{"/tensorflow/data/getnext_time_between",
|
||||
"Milliseconds spent in between calls to tf.data Dataset iterator."},
|
||||
// A typical training step is in the 200ms to 1 second range.
|
||||
// Elapsed time less than 25ms are likely due to multiple devices
|
||||
// calling the iterator's getNext() during the same step. Bucket density
|
||||
// is highest for small time intervals to more accurately measure fast
|
||||
// ingest rates. Buckets from 25ms to 10 seconds.
|
||||
{monitoring::Buckets::Explicit({25., 50., 75., 100., 125., 150., 175.,
|
||||
200., 225., 250., 300., 350., 400.,
|
||||
450., 500., 1000., 10000.})});
|
||||
auto* tf_data_iterator_busy_counter =
|
||||
monitoring::Counter<0>::New("/tensorflow/data/iterator_busy",
|
||||
"The time (in microseconds) during which a "
|
||||
"tf.data iterator was busy processing at "
|
||||
"least one `GetNext()` request.");
|
||||
|
||||
auto* tf_data_iterator_lifetime_counter = monitoring::Counter<0>::New(
|
||||
"/tensorflow/data/iterator_lifetime",
|
||||
"The time (in microseconds) between a tf.data iterator receiving the first "
|
||||
"`GetNext()` request and responding to the last `GetNext()` request.");
|
||||
|
||||
auto* tf_data_optimization_counter = monitoring::Counter<1>::New(
|
||||
"/tensorflow/data/optimization", "tf.data optimization", "name");
|
||||
@ -199,17 +197,21 @@ void RecordTFDataFingerprint(const string& name) {
|
||||
}
|
||||
|
||||
void RecordTFDataGetNextDuration(uint64 duration_us) {
|
||||
static auto* tfdata_getnext_duration_cell =
|
||||
tf_data_getnext_duration_usecs_histogram->GetCell();
|
||||
tfdata_getnext_duration_cell->Add(duration_us);
|
||||
static auto* tf_data_get_next_duration_cell =
|
||||
tf_data_get_next_duration_usecs_histogram->GetCell();
|
||||
tf_data_get_next_duration_cell->Add(duration_us);
|
||||
}
|
||||
|
||||
void RecordTFDataGetNextTimeBetween(uint64 duration_us) {
|
||||
static auto* tfdata_getnext_time_between_cell =
|
||||
tf_data_getnext_time_between_msecs_histogram->GetCell();
|
||||
// Convert to milliseconds for histogram
|
||||
const auto duration_ms = duration_us / 1000;
|
||||
tfdata_getnext_time_between_cell->Add(duration_ms);
|
||||
void RecordTFDataIteratorBusy(uint64 duration_us) {
|
||||
static auto* tf_data_iterator_busy_cell =
|
||||
tf_data_iterator_busy_counter->GetCell();
|
||||
tf_data_iterator_busy_cell->IncrementBy(duration_us);
|
||||
}
|
||||
|
||||
void RecordTFDataIteratorLifetime(uint64 duration_us) {
|
||||
static auto* tf_data_iterator_lifetime_cell =
|
||||
tf_data_iterator_lifetime_counter->GetCell();
|
||||
tf_data_iterator_lifetime_cell->IncrementBy(duration_us);
|
||||
}
|
||||
|
||||
void RecordTFDataOptimization(const string& name, int64 num_changes) {
|
||||
|
@ -59,15 +59,10 @@ void RecordTFDataBytesFetched(int64 num_bytes);
|
||||
// Records the number of times tf.data experiment is applied to input pipelines.
|
||||
void RecordTFDataExperiment(const string& name);
|
||||
|
||||
// Records the time spent in ItertatorResource::GetNext() in microseconds.
|
||||
// Records the time (in microseconds) spent in a single invocation of
|
||||
// `ItertatorResource::GetNext()`.
|
||||
void RecordTFDataGetNextDuration(uint64 duration_us);
|
||||
|
||||
// Records the time spent between IteratorResource::GetNext() calls
|
||||
// in microseconds. Time is measured from the point of returning data from
|
||||
// GetNext() to the point of new data being requested.
|
||||
// This elapsed time corresponds to time spent outside the GetNext() function.
|
||||
void RecordTFDataGetNextTimeBetween(uint64 duration_us);
|
||||
|
||||
// Records the number of times each tf.data fingerprint is used
|
||||
// to measure duplicate pre-processing.
|
||||
//
|
||||
@ -75,6 +70,14 @@ void RecordTFDataGetNextTimeBetween(uint64 duration_us);
|
||||
// created using GraphHash().
|
||||
void RecordTFDataFingerprint(const string& name);
|
||||
|
||||
// Records the time (in microseconds) during which `IteratorResource` was busy
|
||||
// processing at least one `GetNext()` request.
|
||||
void RecordTFDataIteratorBusy(uint64 duration_us);
|
||||
|
||||
// Records the time (in microseconds) between `IteratorResource` receiving the
|
||||
// first `GetNext()` request and responding to the last `GetNext()` request.
|
||||
void RecordTFDataIteratorLifetime(uint64 duration_us);
|
||||
|
||||
// Records the number of independent graph changes resulting from the
|
||||
// application of a tf.data optimization.
|
||||
//
|
||||
|
@ -455,6 +455,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -25,11 +25,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/threadpool_device.h"
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/metrics.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/framework/stats_aggregator.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/kernels/data/captured_function.h"
|
||||
@ -64,11 +66,39 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
|
||||
const char kOutputShapes[] = "output_shapes";
|
||||
const char kOutputTypes[] = "output_types";
|
||||
|
||||
// Safely subtracts x from y avoiding underflow.
|
||||
inline uint64 safe_sub(uint64 x, uint64 y) { return x >= y ? x - y : 0; }
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ constexpr const char* const
|
||||
SerializeIteratorOp::kExternalStatePolicy;
|
||||
|
||||
IteratorResource::IteratorResource(
|
||||
Env* env, const DataTypeVector& output_dtypes,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
std::unique_ptr<DeviceMgr> device_mgr,
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||
FunctionLibraryRuntime* flr)
|
||||
: unbounded_thread_pool_(env, "tf_data_iterator_resource"),
|
||||
device_mgr_(std::move(device_mgr)),
|
||||
iterator_state_(std::make_shared<State>(std::move(flib_def),
|
||||
std::move(pflr), flr,
|
||||
/*iterator=*/nullptr)),
|
||||
output_dtypes_(output_dtypes),
|
||||
output_shapes_(output_shapes),
|
||||
// We do not collect iterator resource metrics for non-CPU devices. This
|
||||
// is a heuristic to avoid collecting metrics for device-side iterators
|
||||
// created by the multi-device iterator mechanism.
|
||||
collect_metrics_(flr->device()->device_type() == DEVICE_CPU) {
|
||||
VLOG(2) << "creating iterator resource";
|
||||
}
|
||||
|
||||
IteratorResource::~IteratorResource() {
|
||||
VLOG(2) << "destroying iterator resource";
|
||||
}
|
||||
|
||||
Status IteratorResource::GetNext(OpKernelContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) {
|
||||
@ -77,35 +107,57 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
|
||||
tf_shared_lock l(mu_);
|
||||
captured_state = iterator_state_;
|
||||
}
|
||||
if (captured_state->iterator) {
|
||||
IteratorContext::Params params(ctx);
|
||||
params.flr = captured_state->flr;
|
||||
params.function_handle_cache = captured_state->function_handle_cache.get();
|
||||
params.resource_mgr = &captured_state->resource_mgr;
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
params.thread_pool = &unbounded_thread_pool_;
|
||||
params.cancellation_manager = &captured_state->cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
RecordCtx record_ctx = CreateRecordCtx(); // Snapshot state prior to work
|
||||
// TODO(mkuchnik): Replace wallclock time with steady clock
|
||||
const uint64 start_time_us = ctx->env()->NowMicros();
|
||||
RecordGetNextStart(record_ctx, start_time_us);
|
||||
auto val = captured_state->iterator->GetNext(
|
||||
IteratorContext(std::move(params)), out_tensors, end_of_sequence);
|
||||
const uint64 end_time_us = ctx->env()->NowMicros();
|
||||
RecordGetNextEnd(record_ctx, end_time_us);
|
||||
metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
|
||||
return val;
|
||||
if (!captured_state->iterator) {
|
||||
return errors::FailedPrecondition(
|
||||
"GetNext() failed because the iterator has not been initialized. "
|
||||
"Ensure that you have run the initializer operation for this iterator "
|
||||
"before getting the next element.");
|
||||
}
|
||||
return errors::FailedPrecondition(
|
||||
"GetNext() failed because the iterator has not been initialized. Ensure "
|
||||
"that you have run the initializer operation for this iterator before "
|
||||
"getting the next element.");
|
||||
IteratorContext::Params params(ctx);
|
||||
params.flr = captured_state->flr;
|
||||
params.function_handle_cache = captured_state->function_handle_cache.get();
|
||||
params.resource_mgr = &captured_state->resource_mgr;
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
params.thread_pool = &unbounded_thread_pool_;
|
||||
params.cancellation_manager = &captured_state->cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
const uint64 start_time_us = ctx->env()->NowMicros();
|
||||
if (collect_metrics_) {
|
||||
mutex_lock l(mu_);
|
||||
if (get_next_end_time_us_ == 0) {
|
||||
// We initialize `get_next_end_time_us_` to the start time of the first
|
||||
// request to make it possible to use the delta between
|
||||
// `get_next_end_time_us_` and subsequent `GetNext()` end time to
|
||||
// incrementally collect the duration of the iterator's lifetime.
|
||||
get_next_end_time_us_ = start_time_us;
|
||||
}
|
||||
if (num_get_next_calls_ == 0) {
|
||||
get_next_start_time_us_ = start_time_us;
|
||||
}
|
||||
num_get_next_calls_++;
|
||||
}
|
||||
auto status = captured_state->iterator->GetNext(
|
||||
IteratorContext(std::move(params)), out_tensors, end_of_sequence);
|
||||
if (collect_metrics_) {
|
||||
const uint64 end_time_us = ctx->env()->NowMicros();
|
||||
metrics::RecordTFDataGetNextDuration(safe_sub(end_time_us, start_time_us));
|
||||
metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
|
||||
mutex_lock l(mu_);
|
||||
metrics::RecordTFDataIteratorLifetime(
|
||||
safe_sub(end_time_us, get_next_end_time_us_));
|
||||
get_next_end_time_us_ = std::max(get_next_end_time_us_, end_time_us);
|
||||
num_get_next_calls_--;
|
||||
if (num_get_next_calls_ == 0) {
|
||||
metrics::RecordTFDataIteratorBusy(
|
||||
safe_sub(get_next_end_time_us_, get_next_start_time_us_));
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status IteratorResource::Save(SerializationContext* ctx,
|
||||
@ -209,71 +261,6 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
IteratorResource::RecordCtx IteratorResource::CreateRecordCtx()
|
||||
TF_LOCKS_EXCLUDED(mu_) {
|
||||
IteratorResource::RecordCtx record_ctx;
|
||||
{
|
||||
tf_shared_lock l(mu_);
|
||||
record_ctx.last_get_next_end_time_us =
|
||||
iterator_state_->last_get_next_end_time_us;
|
||||
}
|
||||
return record_ctx;
|
||||
}
|
||||
|
||||
void IteratorResource::RecordGetNextStart(
|
||||
IteratorResource::RecordCtx& record_ctx, const uint64 start_time_us) {
|
||||
record_ctx.get_next_start_time_us = start_time_us;
|
||||
uint64 last_end_time_us = record_ctx.last_get_next_end_time_us;
|
||||
|
||||
// Records the total amount of time that has elapsed between GetNext()
|
||||
// calls. The time between calls is measured from the point of returning
|
||||
// data from GetNext() to the point of requesting data from GetNext().
|
||||
// A steady clock is preferable. There are three parts to the algorithm
|
||||
// under concurrency which maintain the thread local invariant
|
||||
// last_end_time_us <= start_time_us <= end_time_us and the
|
||||
// IteratorResource invariant that last_end_time_us is increasing:
|
||||
// 1) CreateRecordCtx() is called, which copies the
|
||||
// last_get_next_end_time_us into a thread-local structure
|
||||
// 2) RecordGetNextStart is called with a clock measured after 1),
|
||||
// thus ensuring that local start_time_us >= last_get_next_end_time_us
|
||||
// 3) RecordGetNextEnd is called with a clock measured after 2),
|
||||
// thus ensuring that local end_time_us >= start_time_us. Additionally,
|
||||
// this function updates the IteratorResource last_get_next_end_time_us
|
||||
// with the most recent time. Thus, if two threads call this method,
|
||||
// only the most recent one is visible in the time.
|
||||
// It's worth noting that a mutex over all three pieces may be needed for
|
||||
// strict serialization correctness (i.e., local time may grow stale).
|
||||
if (last_end_time_us) { // last_end_time_us is initialized at 0
|
||||
if (start_time_us >= last_end_time_us) {
|
||||
const uint64 get_next_time_between = start_time_us - last_end_time_us;
|
||||
metrics::RecordTFDataGetNextTimeBetween(get_next_time_between);
|
||||
} else {
|
||||
// Clock went backward (not steady).
|
||||
metrics::RecordTFDataGetNextTimeBetween(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void IteratorResource::RecordGetNextEnd(
|
||||
const IteratorResource::RecordCtx& record_ctx, const uint64 end_time_us)
|
||||
TF_LOCKS_EXCLUDED(mu_) {
|
||||
uint64 start_time_us = record_ctx.get_next_start_time_us;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// Move last_end_time forward if more recent
|
||||
iterator_state_->last_get_next_end_time_us =
|
||||
std::max(end_time_us, iterator_state_->last_get_next_end_time_us);
|
||||
}
|
||||
DCHECK_NE(start_time_us, 0);
|
||||
if (end_time_us >= start_time_us) {
|
||||
const uint64 get_next_duration = end_time_us - start_time_us;
|
||||
metrics::RecordTFDataGetNextDuration(get_next_duration);
|
||||
} else {
|
||||
// Clock went backward (not steady).
|
||||
metrics::RecordTFDataGetNextDuration(0);
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
|
||||
@ -504,8 +491,8 @@ void IteratorHandleOp::Compute(OpKernelContext* context)
|
||||
this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new IteratorResource(
|
||||
context->env(), output_dtypes_, output_shapes_,
|
||||
graph_def_version_, std::move(device_mgr),
|
||||
std::move(flib_def), std::move(pflr), flr);
|
||||
std::move(device_mgr), std::move(flib_def), std::move(pflr),
|
||||
flr);
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
@ -577,8 +564,8 @@ Status AnonymousIteratorHandleOp::CreateResource(
|
||||
FunctionLibraryRuntime* lib, IteratorResource** resource) {
|
||||
std::unique_ptr<DeviceMgr> device_mgr(nullptr);
|
||||
*resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
|
||||
graph_def_version_, std::move(device_mgr),
|
||||
std::move(flib_def), std::move(pflr), lib);
|
||||
std::move(device_mgr), std::move(flib_def),
|
||||
std::move(pflr), lib);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -873,7 +860,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
*ret = new IteratorResource(
|
||||
ctx->env(), output_dtypes_, output_shapes_,
|
||||
graph_def_version_, nullptr, std::move(flib_def),
|
||||
/*device_mgr=*/nullptr, std::move(flib_def),
|
||||
std::move(pflr), flr);
|
||||
return Status::OK();
|
||||
}));
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function_handle_cache.h"
|
||||
#include "tensorflow/core/framework/metrics.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -33,22 +34,12 @@ class IteratorResource : public ResourceBase {
|
||||
public:
|
||||
IteratorResource(Env* env, const DataTypeVector& output_dtypes,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
const int /*unused: graph_def_version*/,
|
||||
std::unique_ptr<DeviceMgr> device_mgr,
|
||||
std::unique_ptr<FunctionLibraryDefinition> flib_def,
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
|
||||
FunctionLibraryRuntime* flr)
|
||||
: unbounded_thread_pool_(env, "tf_data_iterator_resource"),
|
||||
device_mgr_(std::move(device_mgr)),
|
||||
iterator_state_(std::make_shared<State>(std::move(flib_def),
|
||||
std::move(pflr), flr,
|
||||
/*iterator=*/nullptr)),
|
||||
output_dtypes_(output_dtypes),
|
||||
output_shapes_(output_shapes) {
|
||||
VLOG(2) << "constructor";
|
||||
}
|
||||
FunctionLibraryRuntime* flr);
|
||||
|
||||
~IteratorResource() override { VLOG(2) << "destructor"; }
|
||||
~IteratorResource() override;
|
||||
|
||||
// Gets the next output from the iterator managed by this iterator resource.
|
||||
//
|
||||
@ -92,8 +83,7 @@ class IteratorResource : public ResourceBase {
|
||||
flr(flr),
|
||||
pflr(std::move(pflr)),
|
||||
function_handle_cache(absl::make_unique<FunctionHandleCache>(flr)),
|
||||
iterator(std::move(iterator)),
|
||||
last_get_next_end_time_us(0) {}
|
||||
iterator(std::move(iterator)) {}
|
||||
|
||||
~State() { cancellation_manager.StartCancel(); }
|
||||
|
||||
@ -110,35 +100,22 @@ class IteratorResource : public ResourceBase {
|
||||
ResourceMgr resource_mgr;
|
||||
CancellationManager cancellation_manager;
|
||||
std::unique_ptr<DatasetBaseIterator> iterator;
|
||||
uint64 last_get_next_end_time_us;
|
||||
};
|
||||
|
||||
// For thread-local record-keeping state
|
||||
struct RecordCtx {
|
||||
RecordCtx() : get_next_start_time_us(0), last_get_next_end_time_us(0) {}
|
||||
|
||||
uint64 get_next_start_time_us;
|
||||
uint64 last_get_next_end_time_us;
|
||||
};
|
||||
|
||||
// Copies relevant state to the RecordCtx
|
||||
// Intended to be followed by RecordGetNextStart and RecordGetNextEnd.
|
||||
// Recorded times must be measured after this call to enforce ordering.
|
||||
RecordCtx CreateRecordCtx() TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Records that GetNext() has started work.
|
||||
void RecordGetNextStart(RecordCtx& record_ctx, const uint64 start_time_us);
|
||||
|
||||
// Records that GetNext() has ended work.
|
||||
void RecordGetNextEnd(const RecordCtx& record_ctx, const uint64 end_time_us)
|
||||
TF_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
UnboundedThreadPool unbounded_thread_pool_;
|
||||
mutex mu_;
|
||||
// Records the number of currently active `GetNext()` calls.
|
||||
uint64 num_get_next_calls_ TF_GUARDED_BY(mu_) = 0;
|
||||
// Records the start time (in microseconds) of the first `GetNext()` call that
|
||||
// followed the last period of inactivity.
|
||||
uint64 get_next_start_time_us_ TF_GUARDED_BY(mu_) = 0;
|
||||
// Records the end time (in microseconds) of the most recent `GetNext()` call.
|
||||
uint64 get_next_end_time_us_ TF_GUARDED_BY(mu_) = 0;
|
||||
const std::unique_ptr<DeviceMgr> device_mgr_ TF_GUARDED_BY(mu_);
|
||||
std::shared_ptr<State> iterator_state_ TF_GUARDED_BY(mu_);
|
||||
const DataTypeVector output_dtypes_;
|
||||
const std::vector<PartialTensorShape> output_shapes_;
|
||||
const bool collect_metrics_;
|
||||
};
|
||||
|
||||
class IteratorHandleOp : public OpKernel {
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for `tf.data.MultiDeviceIterator`."""
|
||||
"""Tests for the non-public `MultiDeviceIterator` API."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -39,16 +39,12 @@ from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def skip_v2_test_combinations():
|
||||
# TODO(b/121264236): Support v2 behavior for these tests.
|
||||
return combinations.combine(tf_api_version=1, mode=["eager", "graph"])
|
||||
|
||||
|
||||
# TODO(b/121264236): Support v2 behavior for these tests.
|
||||
class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(skip_v2_test_combinations(),
|
||||
combinations.times(test_base.v1_only_combinations(),
|
||||
combinations.combine(num_inits=[0, 1, 42])))
|
||||
def testInitOnly(self, num_inits):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -60,7 +56,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
for _ in range(num_inits):
|
||||
self.evaluate(multi_device_iterator.initializer)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testBasic(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
@ -78,7 +74,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testOneOnSameDevice(self):
|
||||
with ops.device("/cpu:0"):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
@ -97,7 +93,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testRepeatDevices(self):
|
||||
with ops.device("/cpu:0"):
|
||||
dataset = dataset_ops.Dataset.range(20)
|
||||
@ -122,7 +118,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_3)
|
||||
self.evaluate(elem_on_4)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testNotFullyDivisible(self):
|
||||
dataset = dataset_ops.Dataset.range(9)
|
||||
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
@ -142,7 +138,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testGetNextAsOptional(self):
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
@ -179,7 +175,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(elem_on_2_t)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testUneven(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
@ -199,7 +195,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testMultipleInitializationsGraph(self):
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
@ -223,7 +219,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
|
||||
elem_on_2]))
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testMultipleInitializationsEager(self):
|
||||
if not context.executing_eagerly():
|
||||
return
|
||||
@ -239,7 +235,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
|
||||
self.assertEqual([(0, 0), (1, 1)], self.evaluate([elem_on_1, elem_on_2]))
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testBasicGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -260,7 +256,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testUnevenGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -283,7 +279,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
self.evaluate(elem_on_1)
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testGetNextAsOptionalGpu(self):
|
||||
if not test_util.is_gpu_available() or context.executing_eagerly():
|
||||
self.skipTest("No GPU available")
|
||||
@ -320,7 +316,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(elem_on_2_t)
|
||||
|
||||
@combinations.generate(skip_v2_test_combinations())
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testOptimization(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
|
||||
@ -350,7 +346,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||
@combinations.generate(test_base.v2_eager_only_combinations())
|
||||
def testBasic(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -364,7 +360,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
for i, el in enumerate(mdi):
|
||||
self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||
@combinations.generate(test_base.v2_eager_only_combinations())
|
||||
def testBasicFunction(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -387,7 +383,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
for i in range(10):
|
||||
self.assertEqual(queue.dequeue().numpy(), i)
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||
@combinations.generate(test_base.v2_eager_only_combinations())
|
||||
def testFunctionError(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -421,7 +417,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
|
||||
self.assertEqual(queue.size().numpy(), 2)
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||
@combinations.generate(test_base.v2_eager_only_combinations())
|
||||
def testMultipleInitializations(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
@ -436,7 +432,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
|
||||
for i, el in enumerate(multi_device_iterator):
|
||||
self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||
@combinations.generate(test_base.v2_eager_only_combinations())
|
||||
def testLimitedRetracing(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
@ -51,6 +51,11 @@ def graph_only_combinations():
|
||||
return combinations.combine(tf_api_version=[1, 2], mode="graph")
|
||||
|
||||
|
||||
def v1_only_combinations():
|
||||
"""Returns the default test combinations for v1 only tf.data tests."""
|
||||
return combinations.combine(tf_api_version=1, mode=["eager", "graph"])
|
||||
|
||||
|
||||
def v2_only_combinations():
|
||||
"""Returns the default test combinations for v2 only tf.data tests."""
|
||||
return combinations.combine(tf_api_version=2, mode=["eager", "graph"])
|
||||
|
Loading…
Reference in New Issue
Block a user