[tf.data] Various changes to iterator metrics collection.

This CL:
- adds support for collecting the aggregate time a tf.data iterator spent actively servicing `GetNext` requests (accounting for concurrent requests)
- adds support for collecting the "lifetime" of a tf.data iterator, that is, the time between receiving the first `GetNext` request and servicing the last `GetNext` request
- removes support for collecting time between subsequent calls to IteratorGetNextOp

PiperOrigin-RevId: 340474712
Change-Id: Icdfd35c46623160e9faacf1af69f897af88049f6
This commit is contained in:
Jiri Simsa 2020-11-03 10:32:27 -08:00 committed by TensorFlower Gardener
parent 2e8c15ee1c
commit b0140088d4
7 changed files with 158 additions and 187 deletions

View File

@ -94,25 +94,23 @@ auto* tf_data_experiment_counter = monitoring::Counter<1>::New(
auto* tf_data_fingerprint_counter = monitoring::Counter<1>::New(
"/tensorflow/data/fingerprint", "tf.data fingerprint", "name");
auto* tf_data_getnext_duration_usecs_histogram = monitoring::Sampler<0>::New(
auto* tf_data_get_next_duration_usecs_histogram = monitoring::Sampler<0>::New(
{"/tensorflow/data/getnext_duration",
"Microseconds spent fetching an element from tf.data Dataset iterator."},
"Microseconds spent fetching an element from tf.data iterator."},
// Power of 2 with bucket count 10 (1024 microseconds) and 1 second.
{monitoring::Buckets::Explicit(
{2., 4., 8., 16., 32., 64., 128., 256., 512., 1024., 1e6})});
auto* tf_data_getnext_time_between_msecs_histogram =
monitoring::Sampler<0>::New(
{"/tensorflow/data/getnext_time_between",
"Milliseconds spent in between calls to tf.data Dataset iterator."},
// A typical training step is in the 200ms to 1 second range.
// Elapsed time less than 25ms are likely due to multiple devices
// calling the iterator's getNext() during the same step. Bucket density
// is highest for small time intervals to more accurately measure fast
// ingest rates. Buckets from 25ms to 10 seconds.
{monitoring::Buckets::Explicit({25., 50., 75., 100., 125., 150., 175.,
200., 225., 250., 300., 350., 400.,
450., 500., 1000., 10000.})});
// Unlabeled counter that accumulates, in microseconds, the time during which a
// tf.data iterator had at least one `GetNext()` request in progress. It is
// incremented by `RecordTFDataIteratorBusy()`.
auto* tf_data_iterator_busy_counter =
monitoring::Counter<0>::New("/tensorflow/data/iterator_busy",
"The time (in microseconds) during which a "
"tf.data iterator was busy processing at "
"least one `GetNext()` request.");
// Unlabeled counter that accumulates, in microseconds, the time between a
// tf.data iterator receiving its first `GetNext()` request and responding to
// its last one. It is incremented by `RecordTFDataIteratorLifetime()`.
auto* tf_data_iterator_lifetime_counter = monitoring::Counter<0>::New(
"/tensorflow/data/iterator_lifetime",
"The time (in microseconds) between a tf.data iterator receiving the first "
"`GetNext()` request and responding to the last `GetNext()` request.");
auto* tf_data_optimization_counter = monitoring::Counter<1>::New(
"/tensorflow/data/optimization", "tf.data optimization", "name");
@ -199,17 +197,21 @@ void RecordTFDataFingerprint(const string& name) {
}
void RecordTFDataGetNextDuration(uint64 duration_us) {
static auto* tfdata_getnext_duration_cell =
tf_data_getnext_duration_usecs_histogram->GetCell();
tfdata_getnext_duration_cell->Add(duration_us);
static auto* tf_data_get_next_duration_cell =
tf_data_get_next_duration_usecs_histogram->GetCell();
tf_data_get_next_duration_cell->Add(duration_us);
}
void RecordTFDataGetNextTimeBetween(uint64 duration_us) {
static auto* tfdata_getnext_time_between_cell =
tf_data_getnext_time_between_msecs_histogram->GetCell();
// Convert to milliseconds for histogram
const auto duration_ms = duration_us / 1000;
tfdata_getnext_time_between_cell->Add(duration_ms);
// Adds `duration_us` microseconds to the "/tensorflow/data/iterator_busy"
// counter.
void RecordTFDataIteratorBusy(uint64 duration_us) {
  // Resolve the counter cell on the first call only; later calls reuse it.
  static auto* const busy_cell = tf_data_iterator_busy_counter->GetCell();
  busy_cell->IncrementBy(duration_us);
}
// Adds `duration_us` microseconds to the "/tensorflow/data/iterator_lifetime"
// counter.
void RecordTFDataIteratorLifetime(uint64 duration_us) {
  // Resolve the counter cell on the first call only; later calls reuse it.
  static auto* const lifetime_cell =
      tf_data_iterator_lifetime_counter->GetCell();
  lifetime_cell->IncrementBy(duration_us);
}
void RecordTFDataOptimization(const string& name, int64 num_changes) {

View File

@ -59,15 +59,10 @@ void RecordTFDataBytesFetched(int64 num_bytes);
// Records the number of times tf.data experiment is applied to input pipelines.
void RecordTFDataExperiment(const string& name);
// Records the time spent in ItertatorResource::GetNext() in microseconds.
// Records the time (in microseconds) spent in a single invocation of
// `IteratorResource::GetNext()`.
void RecordTFDataGetNextDuration(uint64 duration_us);
// Records the time spent between IteratorResource::GetNext() calls
// in microseconds. Time is measured from the point of returning data from
// GetNext() to the point of new data being requested.
// This elapsed time corresponds to time spent outside the GetNext() function.
void RecordTFDataGetNextTimeBetween(uint64 duration_us);
// Records the number of times each tf.data fingerprint is used
// to measure duplicate pre-processing.
//
@ -75,6 +70,14 @@ void RecordTFDataGetNextTimeBetween(uint64 duration_us);
// created using GraphHash().
void RecordTFDataFingerprint(const string& name);
// Records the time (in microseconds) during which `IteratorResource` was busy
// processing at least one `GetNext()` request.
void RecordTFDataIteratorBusy(uint64 duration_us);
// Records the time (in microseconds) between `IteratorResource` receiving the
// first `GetNext()` request and responding to the last `GetNext()` request.
void RecordTFDataIteratorLifetime(uint64 duration_us);
// Records the number of independent graph changes resulting from the
// application of a tf.data optimization.
//

View File

@ -455,6 +455,7 @@ tf_kernel_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",

View File

@ -25,11 +25,13 @@ limitations under the License.
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/data/captured_function.h"
@ -64,11 +66,39 @@ const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
const char kOutputShapes[] = "output_shapes";
const char kOutputTypes[] = "output_types";
// Returns `x - y`, or 0 when `y > x`, so that the unsigned subtraction never
// wraps around. Used for differences between `NowMicros()` timestamps, where
// the later reading is passed as `x`.
inline uint64 safe_sub(uint64 x, uint64 y) { return x >= y ? x - y : 0; }
} // namespace
/* static */ constexpr const char* const
SerializeIteratorOp::kExternalStatePolicy;
// Constructs an iterator resource that has no iterator attached yet: the
// initial state is created with a null iterator. Ownership of `device_mgr`,
// `flib_def` and `pflr` is moved into the resource; `output_dtypes` and
// `output_shapes` are copied.
// NOTE(review): `flr` is dereferenced unconditionally to read its device type;
// callers presumably always pass a non-null runtime -- confirm.
IteratorResource::IteratorResource(
Env* env, const DataTypeVector& output_dtypes,
const std::vector<PartialTensorShape>& output_shapes,
std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* flr)
: unbounded_thread_pool_(env, "tf_data_iterator_resource"),
device_mgr_(std::move(device_mgr)),
iterator_state_(std::make_shared<State>(std::move(flib_def),
std::move(pflr), flr,
/*iterator=*/nullptr)),
output_dtypes_(output_dtypes),
output_shapes_(output_shapes),
// We do not collect iterator resource metrics for non-CPU devices. This
// is a heuristic to avoid collecting metrics for device-side iterators
// created by the multi-device iterator mechanism.
collect_metrics_(flr->device()->device_type() == DEVICE_CPU) {
VLOG(2) << "creating iterator resource";
}
// Emits a verbose (level 2) log line when the resource is destroyed; performs
// no other explicit work.
IteratorResource::~IteratorResource() {
VLOG(2) << "destroying iterator resource";
}
Status IteratorResource::GetNext(OpKernelContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@ -77,35 +107,57 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
tf_shared_lock l(mu_);
captured_state = iterator_state_;
}
if (captured_state->iterator) {
IteratorContext::Params params(ctx);
params.flr = captured_state->flr;
params.function_handle_cache = captured_state->function_handle_cache.get();
params.resource_mgr = &captured_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
params.thread_pool = &unbounded_thread_pool_;
params.cancellation_manager = &captured_state->cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
RecordCtx record_ctx = CreateRecordCtx(); // Snapshot state prior to work
// TODO(mkuchnik): Replace wallclock time with steady clock
const uint64 start_time_us = ctx->env()->NowMicros();
RecordGetNextStart(record_ctx, start_time_us);
auto val = captured_state->iterator->GetNext(
IteratorContext(std::move(params)), out_tensors, end_of_sequence);
const uint64 end_time_us = ctx->env()->NowMicros();
RecordGetNextEnd(record_ctx, end_time_us);
metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
return val;
if (!captured_state->iterator) {
return errors::FailedPrecondition(
"GetNext() failed because the iterator has not been initialized. "
"Ensure that you have run the initializer operation for this iterator "
"before getting the next element.");
}
return errors::FailedPrecondition(
"GetNext() failed because the iterator has not been initialized. Ensure "
"that you have run the initializer operation for this iterator before "
"getting the next element.");
IteratorContext::Params params(ctx);
params.flr = captured_state->flr;
params.function_handle_cache = captured_state->function_handle_cache.get();
params.resource_mgr = &captured_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
params.thread_pool = &unbounded_thread_pool_;
params.cancellation_manager = &captured_state->cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
const uint64 start_time_us = ctx->env()->NowMicros();
if (collect_metrics_) {
mutex_lock l(mu_);
if (get_next_end_time_us_ == 0) {
// We initialize `get_next_end_time_us_` to the start time of the first
// request to make it possible to use the delta between
// `get_next_end_time_us_` and subsequent `GetNext()` end time to
// incrementally collect the duration of the iterator's lifetime.
get_next_end_time_us_ = start_time_us;
}
if (num_get_next_calls_ == 0) {
get_next_start_time_us_ = start_time_us;
}
num_get_next_calls_++;
}
auto status = captured_state->iterator->GetNext(
IteratorContext(std::move(params)), out_tensors, end_of_sequence);
if (collect_metrics_) {
const uint64 end_time_us = ctx->env()->NowMicros();
metrics::RecordTFDataGetNextDuration(safe_sub(end_time_us, start_time_us));
metrics::RecordTFDataBytesFetched(GetTotalBytes(*out_tensors));
mutex_lock l(mu_);
metrics::RecordTFDataIteratorLifetime(
safe_sub(end_time_us, get_next_end_time_us_));
get_next_end_time_us_ = std::max(get_next_end_time_us_, end_time_us);
num_get_next_calls_--;
if (num_get_next_calls_ == 0) {
metrics::RecordTFDataIteratorBusy(
safe_sub(get_next_end_time_us_, get_next_start_time_us_));
}
}
return status;
}
Status IteratorResource::Save(SerializationContext* ctx,
@ -209,71 +261,6 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
return Status::OK();
}
IteratorResource::RecordCtx IteratorResource::CreateRecordCtx()
TF_LOCKS_EXCLUDED(mu_) {
IteratorResource::RecordCtx record_ctx;
{
tf_shared_lock l(mu_);
record_ctx.last_get_next_end_time_us =
iterator_state_->last_get_next_end_time_us;
}
return record_ctx;
}
void IteratorResource::RecordGetNextStart(
IteratorResource::RecordCtx& record_ctx, const uint64 start_time_us) {
record_ctx.get_next_start_time_us = start_time_us;
uint64 last_end_time_us = record_ctx.last_get_next_end_time_us;
// Records the total amount of time that has elapsed between GetNext()
// calls. The time between calls is measured from the point of returning
// data from GetNext() to the point of requesting data from GetNext().
// A steady clock is preferable. There are three parts to the algorithm
// under concurrency which maintain the thread local invariant
// last_end_time_us <= start_time_us <= end_time_us and the
// IteratorResource invariant that last_end_time_us is increasing:
// 1) CreateRecordCtx() is called, which copies the
// last_get_next_end_time_us into a thread-local structure
// 2) RecordGetNextStart is called with a clock measured after 1),
// thus ensuring that local start_time_us >= last_get_next_end_time_us
// 3) RecordGetNextEnd is called with a clock measured after 2),
// thus ensuring that local end_time_us >= start_time_us. Additionally,
// this function updates the IteratorResource last_get_next_end_time_us
// with the most recent time. Thus, if two threads call this method,
// only the most recent one is visible in the time.
// It's worth noting that a mutex over all three pieces may be needed for
// strict serialization correctness (i.e., local time may grow stale).
if (last_end_time_us) { // last_end_time_us is initialized at 0
if (start_time_us >= last_end_time_us) {
const uint64 get_next_time_between = start_time_us - last_end_time_us;
metrics::RecordTFDataGetNextTimeBetween(get_next_time_between);
} else {
// Clock went backward (not steady).
metrics::RecordTFDataGetNextTimeBetween(0);
}
}
}
void IteratorResource::RecordGetNextEnd(
const IteratorResource::RecordCtx& record_ctx, const uint64 end_time_us)
TF_LOCKS_EXCLUDED(mu_) {
uint64 start_time_us = record_ctx.get_next_start_time_us;
{
mutex_lock l(mu_);
// Move last_end_time forward if more recent
iterator_state_->last_get_next_end_time_us =
std::max(end_time_us, iterator_state_->last_get_next_end_time_us);
}
DCHECK_NE(start_time_us, 0);
if (end_time_us >= start_time_us) {
const uint64 get_next_duration = end_time_us - start_time_us;
metrics::RecordTFDataGetNextDuration(get_next_duration);
} else {
// Clock went backward (not steady).
metrics::RecordTFDataGetNextDuration(0);
}
}
namespace {
// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
@ -504,8 +491,8 @@ void IteratorHandleOp::Compute(OpKernelContext* context)
this](IteratorResource** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
context->env(), output_dtypes_, output_shapes_,
graph_def_version_, std::move(device_mgr),
std::move(flib_def), std::move(pflr), flr);
std::move(device_mgr), std::move(flib_def), std::move(pflr),
flr);
return Status::OK();
}));
@ -577,8 +564,8 @@ Status AnonymousIteratorHandleOp::CreateResource(
FunctionLibraryRuntime* lib, IteratorResource** resource) {
std::unique_ptr<DeviceMgr> device_mgr(nullptr);
*resource = new IteratorResource(ctx->env(), output_dtypes_, output_shapes_,
graph_def_version_, std::move(device_mgr),
std::move(flib_def), std::move(pflr), lib);
std::move(device_mgr), std::move(flib_def),
std::move(pflr), lib);
return Status::OK();
}
@ -873,7 +860,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new IteratorResource(
ctx->env(), output_dtypes_, output_shapes_,
graph_def_version_, nullptr, std::move(flib_def),
/*device_mgr=*/nullptr, std::move(flib_def),
std::move(pflr), flr);
return Status::OK();
}));

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
@ -33,22 +34,12 @@ class IteratorResource : public ResourceBase {
public:
IteratorResource(Env* env, const DataTypeVector& output_dtypes,
const std::vector<PartialTensorShape>& output_shapes,
const int /*unused: graph_def_version*/,
std::unique_ptr<DeviceMgr> device_mgr,
std::unique_ptr<FunctionLibraryDefinition> flib_def,
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
FunctionLibraryRuntime* flr)
: unbounded_thread_pool_(env, "tf_data_iterator_resource"),
device_mgr_(std::move(device_mgr)),
iterator_state_(std::make_shared<State>(std::move(flib_def),
std::move(pflr), flr,
/*iterator=*/nullptr)),
output_dtypes_(output_dtypes),
output_shapes_(output_shapes) {
VLOG(2) << "constructor";
}
FunctionLibraryRuntime* flr);
~IteratorResource() override { VLOG(2) << "destructor"; }
~IteratorResource() override;
// Gets the next output from the iterator managed by this iterator resource.
//
@ -92,8 +83,7 @@ class IteratorResource : public ResourceBase {
flr(flr),
pflr(std::move(pflr)),
function_handle_cache(absl::make_unique<FunctionHandleCache>(flr)),
iterator(std::move(iterator)),
last_get_next_end_time_us(0) {}
iterator(std::move(iterator)) {}
~State() { cancellation_manager.StartCancel(); }
@ -110,35 +100,22 @@ class IteratorResource : public ResourceBase {
ResourceMgr resource_mgr;
CancellationManager cancellation_manager;
std::unique_ptr<DatasetBaseIterator> iterator;
uint64 last_get_next_end_time_us;
};
// For thread-local record-keeping state
struct RecordCtx {
RecordCtx() : get_next_start_time_us(0), last_get_next_end_time_us(0) {}
uint64 get_next_start_time_us;
uint64 last_get_next_end_time_us;
};
// Copies relevant state to the RecordCtx
// Intended to be followed by RecordGetNextStart and RecordGetNextEnd.
// Recorded times must be measured after this call to enforce ordering.
RecordCtx CreateRecordCtx() TF_LOCKS_EXCLUDED(mu_);
// Records that GetNext() has started work.
void RecordGetNextStart(RecordCtx& record_ctx, const uint64 start_time_us);
// Records that GetNext() has ended work.
void RecordGetNextEnd(const RecordCtx& record_ctx, const uint64 end_time_us)
TF_LOCKS_EXCLUDED(mu_);
UnboundedThreadPool unbounded_thread_pool_;
mutex mu_;
// Records the number of currently active `GetNext()` calls.
uint64 num_get_next_calls_ TF_GUARDED_BY(mu_) = 0;
// Records the start time (in microseconds) of the first `GetNext()` call that
// followed the last period of inactivity.
uint64 get_next_start_time_us_ TF_GUARDED_BY(mu_) = 0;
// Records the end time (in microseconds) of the most recent `GetNext()` call.
uint64 get_next_end_time_us_ TF_GUARDED_BY(mu_) = 0;
const std::unique_ptr<DeviceMgr> device_mgr_ TF_GUARDED_BY(mu_);
std::shared_ptr<State> iterator_state_ TF_GUARDED_BY(mu_);
const DataTypeVector output_dtypes_;
const std::vector<PartialTensorShape> output_shapes_;
const bool collect_metrics_;
};
class IteratorHandleOp : public OpKernel {

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.MultiDeviceIterator`."""
"""Tests for the non-public `MultiDeviceIterator` API."""
from __future__ import absolute_import
from __future__ import division
@ -39,16 +39,12 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test
def skip_v2_test_combinations():
# TODO(b/121264236): Support v2 behavior for these tests.
return combinations.combine(tf_api_version=1, mode=["eager", "graph"])
# TODO(b/121264236): Support v2 behavior for these tests.
class MultiDeviceIteratorTest(test_base.DatasetTestBase,
parameterized.TestCase):
@combinations.generate(
combinations.times(skip_v2_test_combinations(),
combinations.times(test_base.v1_only_combinations(),
combinations.combine(num_inits=[0, 1, 42])))
def testInitOnly(self, num_inits):
dataset = dataset_ops.Dataset.range(10)
@ -60,7 +56,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
for _ in range(num_inits):
self.evaluate(multi_device_iterator.initializer)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testBasic(self):
dataset = dataset_ops.Dataset.range(10)
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
@ -78,7 +74,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testOneOnSameDevice(self):
with ops.device("/cpu:0"):
dataset = dataset_ops.Dataset.range(10)
@ -97,7 +93,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testRepeatDevices(self):
with ops.device("/cpu:0"):
dataset = dataset_ops.Dataset.range(20)
@ -122,7 +118,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_3)
self.evaluate(elem_on_4)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testNotFullyDivisible(self):
dataset = dataset_ops.Dataset.range(9)
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
@ -142,7 +138,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testGetNextAsOptional(self):
if context.executing_eagerly():
return
@ -179,7 +175,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(elem_on_2_t)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testUneven(self):
dataset = dataset_ops.Dataset.range(10)
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
@ -199,7 +195,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testMultipleInitializationsGraph(self):
if context.executing_eagerly():
return
@ -223,7 +219,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
elem_on_2]))
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testMultipleInitializationsEager(self):
if not context.executing_eagerly():
return
@ -239,7 +235,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
self.assertEqual([(0, 0), (1, 1)], self.evaluate([elem_on_1, elem_on_2]))
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testBasicGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -260,7 +256,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testUnevenGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -283,7 +279,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
self.evaluate(elem_on_1)
self.evaluate(elem_on_2)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testGetNextAsOptionalGpu(self):
if not test_util.is_gpu_available() or context.executing_eagerly():
self.skipTest("No GPU available")
@ -320,7 +316,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(elem_on_2_t)
@combinations.generate(skip_v2_test_combinations())
@combinations.generate(test_base.v1_only_combinations())
def testOptimization(self):
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(testing.assert_next(["MemoryCacheImpl"]))
@ -350,7 +346,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase,
class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
parameterized.TestCase):
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.v2_eager_only_combinations())
def testBasic(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -364,7 +360,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
for i, el in enumerate(mdi):
self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.v2_eager_only_combinations())
def testBasicFunction(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -387,7 +383,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
for i in range(10):
self.assertEqual(queue.dequeue().numpy(), i)
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.v2_eager_only_combinations())
def testFunctionError(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -421,7 +417,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
self.assertEqual(queue.size().numpy(), 2)
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.v2_eager_only_combinations())
def testMultipleInitializations(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@ -436,7 +432,7 @@ class OwnedMultiDeviceIteratorTest(test_base.DatasetTestBase,
for i, el in enumerate(multi_device_iterator):
self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
@combinations.generate(test_base.v2_eager_only_combinations())
def testLimitedRetracing(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")

View File

@ -51,6 +51,11 @@ def graph_only_combinations():
return combinations.combine(tf_api_version=[1, 2], mode="graph")
def v1_only_combinations():
  """Builds the test combinations for tf.data tests that only run under v1.

  Returns:
    The result of `combinations.combine` for `tf_api_version=1` with both the
    "eager" and "graph" modes.
  """
  modes = ["eager", "graph"]
  return combinations.combine(tf_api_version=1, mode=modes)
def v2_only_combinations():
  """Builds the test combinations for tf.data tests that only run under v2.

  Returns:
    The result of `combinations.combine` for `tf_api_version=2` with both the
    "eager" and "graph" modes.
  """
  modes = ["eager", "graph"]
  return combinations.combine(tf_api_version=2, mode=modes)