[tf.data] Improving detection of infinitely repeated datasets in the presence of errors.

PiperOrigin-RevId: 280200721
Change-Id: Icfaffb567b970da140e9b0d3a6c2093452893f01
parent: d55375021a
commit: b95598fac0
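What this change does: a shuffle(...).repeat() pipeline with count_ == -1 previously detected an empty input via a local first_call flag, which was effectively lost whenever the first element raised an error (the flag is only set while the input iterator is being created); a later GetNext() call would then loop forever over the empty input instead of reporting end of sequence. The commit replaces the local flag with a persistent, checkpointable data_produced_ member. A minimal sketch of the resulting user-facing behavior, modeled on the new testEmptyDataset below (TF 2.x eager API; assumes a build that includes this change):

import tensorflow as tf

def map_fn(x):
  # The assertion fails at runtime (x == 1), so the dataset never yields.
  with tf.control_dependencies([tf.debugging.assert_equal(x, 0)]):
    return tf.identity(x)

dataset = tf.data.Dataset.from_tensors(1).map(map_fn).cache()
dataset = dataset.shuffle(buffer_size=10).repeat()
iterator = iter(dataset)

try:
  next(iterator)  # First call: InvalidArgumentError from the failed assertion.
except tf.errors.InvalidArgumentError:
  pass

try:
  next(iterator)  # With the fix: end of sequence, not an infinite loop.
except StopIteration:  # Python's next() converts OutOfRangeError.
  print("empty repeated dataset detected")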
@@ -403,6 +403,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
                                     bool* end_of_sequence) {
   profiler::TraceMe activity([&] { return BuildTraceMeName(); },
                              profiler::TraceMeLevel::kInfo);
+  VLOG(3) << prefix() << " GetNext";
   RecordStart(ctx, /*stop_output=*/true);
   Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
   if (s.ok() && !*end_of_sequence) RecordElement(ctx);
@@ -839,9 +839,13 @@ class DatasetBaseIterator : public IteratorBase {
 
   explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
     params_.dataset->Ref();
+    VLOG(3) << prefix() << " constructor";
   }
 
-  ~DatasetBaseIterator() override { params_.dataset->Unref(); }
+  ~DatasetBaseIterator() override {
+    VLOG(3) << prefix() << " destructor";
+    params_.dataset->Unref();
+  }
 
   const DataTypeVector& output_dtypes() const override {
     return params_.dataset->output_dtypes();
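The VLOG(3) statements in GetNext and in the constructor/destructor make iterator lifetimes visible, which helps when diagnosing pipelines that loop without producing data. As a usage note (standard TF C++ logging environment variables, to the best of my knowledge): run with TF_CPP_MIN_VLOG_LEVEL=3, or scope the verbosity to one translation unit with TF_CPP_VMODULE=dataset=3.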
@@ -54,6 +54,7 @@ const int64 kLogIntervalMicros = 10 * 1000000;  // 10 seconds.
 const int64 kMaxEpochsInBuffer = 3;
 
 constexpr char kNumRandomSamples[] = "num_random_samples";
+constexpr char kDataProduced[] = "data_produced";
 constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
 constexpr char kEpoch[] = "epoch";
 constexpr char kNumElements[] = "num_elements";
@@ -138,9 +139,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
       mutex_lock l(mu_);
       int64 start_micros = ctx->env()->NowMicros();
       int64 num_log_entries = 0;
-      bool first_call = false;
       if (!input_impl_ && epoch_ == 0) {
-        first_call = true;
         TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
             ctx, this->prefix(), &input_impl_));
       }
@@ -158,13 +157,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
           TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
                                                   &end_of_input_sequence));
           if (!end_of_input_sequence) {
-            first_call = false;
+            data_produced_ = true;
             break;
           }
-          if (first_call && this->dataset()->count_ == -1) {
-            // If the first call to GetNext() fails because the end
-            // of sequence has been reached, we terminate the
-            // iteration immediately. (Otherwise, this iterator
+          if (!data_produced_ && this->dataset()->count_ == -1) {
+            // If we encounter the end of sequence without producing data, we
+            // terminate the iteration immediately. (Otherwise, this iterator
             // would loop infinitely and never produce a value.)
             *end_of_sequence = true;
             return Status::OK();
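The key difference: first_call was a stack variable scoped to a single GetNext invocation, so after an error the next call started over with the flag unset and the empty-input check never fired. The new data_produced_ member survives across calls (and, below, across checkpoints). A minimal Python sketch of the idea, with illustrative names that are not the TF implementation:

class InfiniteRepeatIterator:
  """Repeats make_source() forever, but stops if no data is ever produced."""

  def __init__(self, make_source):
    self._make_source = make_source
    self._source = iter(make_source())
    self._data_produced = False  # Persists across get_next() calls.

  def get_next(self):
    while True:
      try:
        element = next(self._source)
        self._data_produced = True
        return element
      except StopIteration:
        if not self._data_produced:
          # End of sequence without ever producing data: terminate instead
          # of restarting the (empty) source forever.
          raise
        self._source = iter(self._make_source())  # Start the next epoch.

# Usage: an empty source raises immediately instead of hanging.
it = InfiniteRepeatIterator(lambda: [])
try:
  it.get_next()
except StopIteration:
  print("empty repeated source detected")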
@@ -289,6 +287,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
           }
         }
       }
+      if (data_produced_) {
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(this->full_name(kDataProduced), ""));
+      }
 
       return Status::OK();
     }
@@ -353,6 +355,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
           }
         }
       }
+      data_produced_ = reader->Contains(this->full_name(kDataProduced));
 
       return Status::OK();
     }
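Note the checkpoint encoding: SaveInternal writes the kDataProduced key only when the flag is true (with an empty value), and RestoreInternal recovers the boolean from the key's mere presence via reader->Contains(). A tiny Python sketch of this presence-encoding idiom (the dict stands in for the checkpoint reader/writer; names are illustrative):

def save(state, data_produced):
  # Write the key only when the flag is set; the value itself is unused.
  if data_produced:
    state["data_produced"] = ""

def restore(state):
  # Recover the flag from the key's presence, not from its value.
  return "data_produced" in state

state = {}
save(state, True)
assert restore(state) is True
assert restore({}) is False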
@@ -394,6 +397,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
     random::SingleSampleAdapter<random::PhiloxRandom> generator_
         GUARDED_BY(mu_);
     int64 num_random_samples_ GUARDED_BY(mu_) = 0;
+    bool data_produced_ GUARDED_BY(mu_) = false;
   };
 
   const DatasetBase* const input_;
@@ -32,6 +32,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -308,6 +309,28 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
     consume()
     self.assertAllEqual(self.evaluate(counter_var), 10)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testEmptyDataset(self):
+    dataset = dataset_ops.Dataset.from_tensors(1)
+
+    def map_fn(x):
+      with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+        return x
+
+    dataset = dataset.map(map_fn)
+    dataset = dataset.cache()
+    dataset = dataset.shuffle(buffer_size=10).repeat()
+
+    get_next = self.getNext(dataset)
+
+    # First time around, we get an error for the failed assertion.
+    with self.assertRaises(errors.InvalidArgumentError):
+      self.evaluate(get_next())
+
+    # Second time around, we get an EOF because the cached dataset is empty.
+    with self.assertRaises(errors.OutOfRangeError):
+      self.evaluate(get_next())
+
 
 if __name__ == "__main__":
   test.main()
@@ -66,11 +66,7 @@ ASYNC = 1
 MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
 MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
 
-# TODO(b/143164764): Currently _KEEP_ALIVE_SECS is set to a very long time
-# (i.e. 30 days) because the server may deadlock when destroying the eager
-# context. This may cause memory leak in the headless TPU case, we should change
-# it back to 600 once the deadlock is fixed.
-_KEEP_ALIVE_SECS = 2592000
+_KEEP_ALIVE_SECS = 600
 
 _python_eager_context_create_counter = monitoring.Counter(
     "/tensorflow/api/python/eager_context_create_counter",