diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index fc6f8fdbb90..261d9302695 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -403,6 +403,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx, bool* end_of_sequence) { profiler::TraceMe activity([&] { return BuildTraceMeName(); }, profiler::TraceMeLevel::kInfo); + VLOG(3) << prefix() << " GetNext"; RecordStart(ctx, /*stop_output=*/true); Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); if (s.ok() && !*end_of_sequence) RecordElement(ctx); diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 675f1b6918a..3663a26fbf4 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -839,9 +839,13 @@ class DatasetBaseIterator : public IteratorBase { explicit DatasetBaseIterator(const BaseParams& params) : params_(params) { params_.dataset->Ref(); + VLOG(3) << prefix() << " constructor"; } - ~DatasetBaseIterator() override { params_.dataset->Unref(); } + ~DatasetBaseIterator() override { + VLOG(3) << prefix() << " destructor"; + params_.dataset->Unref(); + } const DataTypeVector& output_dtypes() const override { return params_.dataset->output_dtypes(); diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 6f3b939bac5..674467abedf 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -54,6 +54,7 @@ const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds. const int64 kMaxEpochsInBuffer = 3; constexpr char kNumRandomSamples[] = "num_random_samples"; +constexpr char kDataProduced[] = "data_produced"; constexpr char kEndOfInputSequence[] = "end_of_input_sequence"; constexpr char kEpoch[] = "epoch"; constexpr char kNumElements[] = "num_elements"; @@ -138,9 +139,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { mutex_lock l(mu_); int64 start_micros = ctx->env()->NowMicros(); int64 num_log_entries = 0; - bool first_call = false; if (!input_impl_ && epoch_ == 0) { - first_call = true; TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator( ctx, this->prefix(), &input_impl_)); } @@ -158,13 +157,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element, &end_of_input_sequence)); if (!end_of_input_sequence) { - first_call = false; + data_produced_ = true; break; } - if (first_call && this->dataset()->count_ == -1) { - // If the first call to GetNext() fails because the end - // of sequence has been reached, we terminate the - // iteration immediately. (Otherwise, this iterator + if (!data_produced_ && this->dataset()->count_ == -1) { + // If we encounter the end of sequence without producing data, we + // terminate the iteration immediately. (Otherwise, this iterator // would loop infinitely and never produce a value.) *end_of_sequence = true; return Status::OK(); @@ -289,6 +287,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { } } } + if (data_produced_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(this->full_name(kDataProduced), "")); + } return Status::OK(); } @@ -353,6 +355,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { } } } + data_produced_ = reader->Contains(this->full_name(kDataProduced)); return Status::OK(); } @@ -394,6 +397,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { random::SingleSampleAdapter generator_ GUARDED_BY(mu_); int64 num_random_samples_ GUARDED_BY(mu_) = 0; + bool data_produced_ GUARDED_BY(mu_) = false; }; const DatasetBase* const input_; diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py index b2d2d23a8fa..7f801e1b5f4 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -308,6 +309,28 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): consume() self.assertAllEqual(self.evaluate(counter_var), 10) + @combinations.generate(test_base.default_test_combinations()) + def testEmptyDataset(self): + dataset = dataset_ops.Dataset.from_tensors(1) + + def map_fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return x + + dataset = dataset.map(map_fn) + dataset = dataset.cache() + dataset = dataset.shuffle(buffer_size=10).repeat() + + get_next = self.getNext(dataset) + + # First time around, we get an error for the failed assertion. + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(get_next()) + + # Second time around, we get an EOF because the cached dataset is empty. + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(get_next()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 98e580b1dfb..2757cc4667d 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -66,11 +66,7 @@ ASYNC = 1 MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL -# TODO(b/143164764): Currently _KEEP_ALIVE_SECS is set to a very long time -# (i.e. 30 days) because the server may deadlock when destroying the eager -# context. This may cause memory leak in the headless TPU case, we should change -# it back to 600 once the deadlock is fixed. -_KEEP_ALIVE_SECS = 2592000 +_KEEP_ALIVE_SECS = 600 _python_eager_context_create_counter = monitoring.Counter( "/tensorflow/api/python/eager_context_create_counter",