[tf.data] Improving detection of infinitely repeated datasets in the presence of errors.
PiperOrigin-RevId: 280200721 Change-Id: Icfaffb567b970da140e9b0d3a6c2093452893f01
This commit is contained in:
parent
d55375021a
commit
b95598fac0
@ -403,6 +403,7 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx,
|
|||||||
bool* end_of_sequence) {
|
bool* end_of_sequence) {
|
||||||
profiler::TraceMe activity([&] { return BuildTraceMeName(); },
|
profiler::TraceMe activity([&] { return BuildTraceMeName(); },
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
|
VLOG(3) << prefix() << " GetNext";
|
||||||
RecordStart(ctx, /*stop_output=*/true);
|
RecordStart(ctx, /*stop_output=*/true);
|
||||||
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
|
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
|
||||||
if (s.ok() && !*end_of_sequence) RecordElement(ctx);
|
if (s.ok() && !*end_of_sequence) RecordElement(ctx);
|
||||||
|
@ -839,9 +839,13 @@ class DatasetBaseIterator : public IteratorBase {
|
|||||||
|
|
||||||
explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
|
explicit DatasetBaseIterator(const BaseParams& params) : params_(params) {
|
||||||
params_.dataset->Ref();
|
params_.dataset->Ref();
|
||||||
|
VLOG(3) << prefix() << " constructor";
|
||||||
}
|
}
|
||||||
|
|
||||||
~DatasetBaseIterator() override { params_.dataset->Unref(); }
|
~DatasetBaseIterator() override {
|
||||||
|
VLOG(3) << prefix() << " destructor";
|
||||||
|
params_.dataset->Unref();
|
||||||
|
}
|
||||||
|
|
||||||
const DataTypeVector& output_dtypes() const override {
|
const DataTypeVector& output_dtypes() const override {
|
||||||
return params_.dataset->output_dtypes();
|
return params_.dataset->output_dtypes();
|
||||||
|
@ -54,6 +54,7 @@ const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds.
|
|||||||
const int64 kMaxEpochsInBuffer = 3;
|
const int64 kMaxEpochsInBuffer = 3;
|
||||||
|
|
||||||
constexpr char kNumRandomSamples[] = "num_random_samples";
|
constexpr char kNumRandomSamples[] = "num_random_samples";
|
||||||
|
constexpr char kDataProduced[] = "data_produced";
|
||||||
constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
|
constexpr char kEndOfInputSequence[] = "end_of_input_sequence";
|
||||||
constexpr char kEpoch[] = "epoch";
|
constexpr char kEpoch[] = "epoch";
|
||||||
constexpr char kNumElements[] = "num_elements";
|
constexpr char kNumElements[] = "num_elements";
|
||||||
@ -138,9 +139,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
|||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
int64 start_micros = ctx->env()->NowMicros();
|
int64 start_micros = ctx->env()->NowMicros();
|
||||||
int64 num_log_entries = 0;
|
int64 num_log_entries = 0;
|
||||||
bool first_call = false;
|
|
||||||
if (!input_impl_ && epoch_ == 0) {
|
if (!input_impl_ && epoch_ == 0) {
|
||||||
first_call = true;
|
|
||||||
TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
|
TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
|
||||||
ctx, this->prefix(), &input_impl_));
|
ctx, this->prefix(), &input_impl_));
|
||||||
}
|
}
|
||||||
@ -158,13 +157,12 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
|||||||
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
|
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
|
||||||
&end_of_input_sequence));
|
&end_of_input_sequence));
|
||||||
if (!end_of_input_sequence) {
|
if (!end_of_input_sequence) {
|
||||||
first_call = false;
|
data_produced_ = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (first_call && this->dataset()->count_ == -1) {
|
if (!data_produced_ && this->dataset()->count_ == -1) {
|
||||||
// If the first call to GetNext() fails because the end
|
// If we encounter the end of sequence without producing data, we
|
||||||
// of sequence has been reached, we terminate the
|
// terminate the iteration immediately. (Otherwise, this iterator
|
||||||
// iteration immediately. (Otherwise, this iterator
|
|
||||||
// would loop infinitely and never produce a value.)
|
// would loop infinitely and never produce a value.)
|
||||||
*end_of_sequence = true;
|
*end_of_sequence = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -289,6 +287,10 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (data_produced_) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
writer->WriteScalar(this->full_name(kDataProduced), ""));
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -353,6 +355,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
data_produced_ = reader->Contains(this->full_name(kDataProduced));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -394,6 +397,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
|||||||
random::SingleSampleAdapter<random::PhiloxRandom> generator_
|
random::SingleSampleAdapter<random::PhiloxRandom> generator_
|
||||||
GUARDED_BY(mu_);
|
GUARDED_BY(mu_);
|
||||||
int64 num_random_samples_ GUARDED_BY(mu_) = 0;
|
int64 num_random_samples_ GUARDED_BY(mu_) = 0;
|
||||||
|
bool data_produced_ GUARDED_BY(mu_) = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
const DatasetBase* const input_;
|
const DatasetBase* const input_;
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -308,6 +309,28 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
consume()
|
consume()
|
||||||
self.assertAllEqual(self.evaluate(counter_var), 10)
|
self.assertAllEqual(self.evaluate(counter_var), 10)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testEmptyDataset(self):
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(1)
|
||||||
|
|
||||||
|
def map_fn(x):
|
||||||
|
with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
|
||||||
|
return x
|
||||||
|
|
||||||
|
dataset = dataset.map(map_fn)
|
||||||
|
dataset = dataset.cache()
|
||||||
|
dataset = dataset.shuffle(buffer_size=10).repeat()
|
||||||
|
|
||||||
|
get_next = self.getNext(dataset)
|
||||||
|
|
||||||
|
# First time around, we get an error for the failed assertion.
|
||||||
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
|
self.evaluate(get_next())
|
||||||
|
|
||||||
|
# Second time around, we get an EOF because the cached dataset is empty.
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
self.evaluate(get_next())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -66,11 +66,7 @@ ASYNC = 1
|
|||||||
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
|
MIRRORING_NONE = pywrap_tensorflow.TFE_MIRRORING_NONE
|
||||||
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
|
MIRRORING_ALL = pywrap_tensorflow.TFE_MIRRORING_ALL
|
||||||
|
|
||||||
# TODO(b/143164764): Currently _KEEP_ALIVE_SECS is set to a very long time
|
_KEEP_ALIVE_SECS = 600
|
||||||
# (i.e. 30 days) because the server may deadlock when destroying the eager
|
|
||||||
# context. This may cause memory leak in the headless TPU case, we should change
|
|
||||||
# it back to 600 once the deadlock is fixed.
|
|
||||||
_KEEP_ALIVE_SECS = 2592000
|
|
||||||
|
|
||||||
_python_eager_context_create_counter = monitoring.Counter(
|
_python_eager_context_create_counter = monitoring.Counter(
|
||||||
"/tensorflow/api/python/eager_context_create_counter",
|
"/tensorflow/api/python/eager_context_create_counter",
|
||||||
|
Loading…
Reference in New Issue
Block a user