[tf.data] Improving detection of infinitely repeated datasets in the presence of errors.

PiperOrigin-RevId: 280200721
Change-Id: Icfaffb567b970da140e9b0d3a6c2093452893f01
This commit is contained in:
Jiri Simsa 2019-11-13 08:12:28 -08:00 committed by TensorFlower Gardener
parent d55375021a
commit b95598fac0
5 changed files with 41 additions and 13 deletions

View File

@@ -403,6 +403,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
bool* end_of_sequence) {
profiler::TraceMe activity([&] { return BuildTraceMeName(); },
profiler::TraceMeLevel::kInfo);
VLOG(3) << prefix() << " GetNext";
RecordStart(ctx, /*stop_output=*/true);
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
if (s.ok() && !*end_of_sequence) RecordElement(ctx);

View File

@@ -839,9 +839,13 @@ class DatasetBaseIterator : public IteratorBase {
explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
params_.dataset->Ref();
VLOG(3) << prefix() << " constructor";
}
~DatasetBaseIterator() override { params_.dataset->Unref(); }
~DatasetBaseIterator() override {
VLOG(3) << prefix() << " destructor";
params_.dataset->Unref();
}
const DataTypeVector& output_dtypes() const override {
return params_.dataset->output_dtypes();

View File

@@ -54,6 +54,7 @@ const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds.
const int64 kMaxEpochsInBuffer = 3;
constexpr char kNumRandomSamples[] = "num_random_samples";
constexpr char kDataProduced[] = "data_produced";
constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
constexpr char kEpoch[] = "epoch";
constexpr char kNumElements[] = "num_elements";
@@ -138,9 +139,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
mutex_lock l(mu_);
int64 start_micros = ctx->env()->NowMicros();
int64 num_log_entries = 0;
bool first_call = false;
if (!input_impl_ && epoch_ == 0) {
first_call = true;
TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
ctx, this->prefix(), &input_impl_));
}
@@ -158,13 +157,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
&end_of_input_sequence));
if (!end_of_input_sequence) {
first_call = false;
data_produced_ = true;
break;
}
if (first_call && this->dataset()->count_ == -1) {
// If the first call to GetNext() fails because the end
// of sequence has been reached, we terminate the
// iteration immediately. (Otherwise, this iterator
if (!data_produced_ && this->dataset()->count_ == -1) {
// If we encounter the end of sequence without producing data, we
// terminate the iteration immediately. (Otherwise, this iterator
// would loop infinitely and never produce a value.)
*end_of_sequence = true;
return Status::OK();
@@ -289,6 +287,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
}
}
}
if (data_produced_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(this->full_name(kDataProduced), ""));
}
return Status::OK();
}
@@ -353,6 +355,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
}
}
}
data_produced_ = reader->Contains(this->full_name(kDataProduced));
return Status::OK();
}
@@ -394,6 +397,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
random::SingleSampleAdapter<random::PhiloxRandom> generator_
GUARDED_BY(mu_);
int64 num_random_samples_ GUARDED_BY(mu_) = 0;
bool data_produced_ GUARDED_BY(mu_) = false;
};
const DatasetBase* const input_;

View File

@@ -32,6 +32,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -308,6 +309,28 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
consume()
self.assertAllEqual(self.evaluate(counter_var), 10)
@combinations.generate(test_base.default_test_combinations())
def testEmptyDataset(self):
dataset = dataset_ops.Dataset.from_tensors(1)
def map_fn(x):
with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
return x
dataset = dataset.map(map_fn)
dataset = dataset.cache()
dataset = dataset.shuffle(buffer_size=10).repeat()
get_next = self.getNext(dataset)
# First time around, we get an error for the failed assertion.
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(get_next())
# Second time around, we get an EOF because the cached dataset is empty.
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
if __name__ == "__main__":
test.main()

View File

@@ -66,11 +66,7 @@ ASYNC = 1
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
# TODO(b/143164764): Currently _KEEP_ALIVE_SECS is set to a very long time
# (i.e. 30 days) because the server may deadlock when destroying the eager
# context. This may cause memory leak in the headless TPU case, we should change
# it back to 600 once the deadlock is fixed.
_KEEP_ALIVE_SECS = 2592000
_KEEP_ALIVE_SECS = 600
_python_eager_context_create_counter = monitoring.Counter(
"/tensorflow/api/python/eager_context_create_counter",