Move from 'initialize' to 'initializer' to be more consistent with the tf.data iterator APIs.

PiperOrigin-RevId: 289922514
Change-Id: I2f82cade789d707f287b9915d9856e2683aaa9f6
Anjali Sridhar 2020-01-15 13:14:33 -08:00 committed by TensorFlower Gardener
parent c9e0a34352
commit 019a2531c8
12 changed files with 30 additions and 19 deletions
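The change is mechanical at call sites: code that previously ran iterator.initialize() now reads the initializer property, matching the tf.data iterator API. A minimal before/after sketch (illustrative only; assumes TF 1.x-style graph execution, a single replica, and a strategy that exposes make_input_fn_iterator, such as tf.compat.v1.distribute.MirroredStrategy):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

strategy = tf.distribute.MirroredStrategy()
iterator = strategy.make_input_fn_iterator(
    lambda _: tf.data.Dataset.from_tensors([[1.]]).repeat(10))

with tf.Session() as sess:
  # Before this change: sess.run(iterator.initialize())
  sess.run(iterator.initializer)
  # With a single replica, get_next() yields a plain tensor.
  value = sess.run(iterator.get_next())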

@@ -357,7 +357,7 @@ class CollectiveAllReduceStrategyTestBase(
self.cached_session(config=config,
target=master_target) as sess:
iterator = distribution.make_input_fn_iterator(input_fn)
sess.run(iterator.initialize())
sess.run(iterator.initializer)
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -375,7 +375,7 @@ class CollectiveAllReduceStrategyTestBase(
# After re-initializing the iterator, should be able to iterate again.
if test_reinitialize:
sess.run(iterator.initialize())
sess.run(iterator.initializer)
for expected_value in expected_values:
next_element = iterator.get_next()

@@ -128,6 +128,7 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
@@ -2285,13 +2286,24 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
def get_next(self):
return self._iterator.get_next()
@deprecated(None, "Use the iterator's `initializer` property instead.")
def initialize(self):
"""Initialize underlying iterators.
Returns:
A list of any initializer ops that should be run.
"""
if eager_context.executing_eagerly():
self._iterator = self._dataset.make_one_shot_iterator()
return []
else:
return [self._iterator.initializer]
@property
def initializer(self):
"""Returns a list of ops that initialize the iterator."""
return self.initialize()
# TODO(priyag): Delete this once all strategies use global batch size.
@property
def _global_batch_size(self):
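The hunk above keeps backward compatibility: initialize() survives but is marked deprecated, and the new initializer property simply delegates to it (in eager mode initialize() builds a one-shot iterator and returns an empty list, so there is nothing to run). A generic sketch of that delegation pattern, with illustrative names rather than the actual TF class:

from tensorflow.python.util.deprecation import deprecated


class _ExampleIteratorWrapper(object):
  """Illustrative wrapper showing the deprecation-by-delegation pattern."""

  def __init__(self, init_ops):
    self._init_ops = list(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Returns a list of ops that initialize the underlying iterators."""
    return list(self._init_ops)

  @property
  def initializer(self):
    """tf.data-style spelling; forwards to the deprecated method."""
    return self.initialize()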

@@ -45,6 +45,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
def get_distributed_dataset(dataset,
@@ -348,8 +349,7 @@ class DistributedIterator(object):
class DistributedIteratorV1(DistributedIterator):
"""Input Iterator for tf.data.DatasetV1."""
# TODO(anjalisridhar): Move to using `initializer` instead to be consistent
# with tf.data iterator APIs.
@deprecated(None, "Use the iterator's `initializer` property instead.")
def initialize(self):
"""Initialze underlying iterators.
@@ -360,6 +360,7 @@ class DistributedIteratorV1(DistributedIterator):
@property
def initializer(self):
"""Returns a list of ops that initialize the iterator."""
return self.initialize()
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.

@@ -171,9 +171,7 @@ class DistributedIteratorTestBase(test.TestCase):
if iteration_type == "get_next":
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
if isinstance(iterator, input_lib.DistributedIteratorV1):
evaluate(control_flow_ops.group(iterator.initialize()))
else:
evaluate(control_flow_ops.group(iterator._initializer))
evaluate(control_flow_ops.group(iterator.initializer))
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -192,7 +190,7 @@ class DistributedIteratorTestBase(test.TestCase):
# After re-initializing the iterator, should be able to iterate again.
if isinstance(iterator, input_lib.DistributedIteratorV1):
evaluate(control_flow_ops.group(iterator.initialize()))
evaluate(control_flow_ops.group(iterator.initializer))
else:
evaluate(control_flow_ops.group(iterator._initializer))

@@ -101,7 +101,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
metric, args=(iterator.get_next(),)))
batches_per_update = distribution.num_replicas_in_sync
self.evaluate(iterator.initialize())
self.evaluate(iterator.initializer)
self.evaluate([v.initializer for v in metric.variables])
batches_consumed = 0

@@ -124,7 +124,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
# replace "distribution.num_replicas_in_sync" with "1".
batches_per_update = distribution.num_replicas_in_sync
self.evaluate(iterator.initialize())
self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
batches_consumed = 0

@@ -65,7 +65,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def _get_iterator(self, strategy, input_fn):
iterator = strategy.make_input_fn_iterator(lambda _: input_fn())
self.evaluate(iterator.initialize())
self.evaluate(iterator.initializer)
return iterator
@combinations.generate(

@@ -216,7 +216,7 @@ class MirroredVariableCreationTest(test.TestCase):
iterator = distribution.make_input_fn_iterator(
lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
self.evaluate(iterator.initialize())
self.evaluate(iterator.initializer)
features = iterator.get_next()
with distribution.scope():

@@ -536,7 +536,7 @@ class ParameterServerStrategyTestBase(
self.cached_session(config=config,
target=master_target) as sess:
iterator = distribution.make_input_fn_iterator(input_fn)
sess.run(iterator.initialize())
sess.run(iterator.initializer)
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -554,7 +554,7 @@ class ParameterServerStrategyTestBase(
# After re-initializing the iterator, should be able to iterate again.
if test_reinitialize:
sess.run(iterator.initialize())
sess.run(iterator.initializer)
for expected_value in expected_values:
next_element = iterator.get_next()

@@ -55,7 +55,7 @@ class StandardInputStep(Step):
self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
def initialize(self):
return self._iterator.initialize()
return self._iterator.initializer
class StandardSingleLossStep(StandardInputStep):

@@ -344,7 +344,7 @@ class DistributionTestBase(test.TestCase):
test_reinitialize=True,
ignore_order=False):
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
evaluate(iterator.initialize())
evaluate(iterator.initializer)
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -362,7 +362,7 @@ class DistributionTestBase(test.TestCase):
# After re-initializing the iterator, should be able to iterate again.
if test_reinitialize:
evaluate(iterator.initialize())
evaluate(iterator.initializer)
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -414,7 +414,7 @@ class DistributionTestBase(test.TestCase):
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
i = strategy.make_dataset_iterator(ds)
self.evaluate(i.initialize())
self.evaluate(i.initializer)
def run_and_concatenate(strategy, i):
x, y = strategy.experimental_run(lambda z: z, i)

@@ -591,7 +591,7 @@ def get_iterator(dataset, distribution_strategy):
def initialize_iterator(iterator, distribution_strategy):
with distribution_strategy.scope():
init_op = control_flow_ops.group(iterator.initialize())
init_op = control_flow_ops.group(iterator.initializer)
if not context.executing_eagerly():
K.get_session((init_op,)).run(init_op)
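Since initializer is still a list of per-device initializer ops, the Keras helper above wraps it in control_flow_ops.group before running it. A standalone sketch of that grouping with plain tf.data iterators (illustrative only; uses tf.compat.v1 graph mode):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ds = tf.data.Dataset.range(4)
it_a = tf.data.make_initializable_iterator(ds)
it_b = tf.data.make_initializable_iterator(ds)

# group() collapses a list of initializer ops into one op that is run once
# before iteration starts.
init_op = tf.group([it_a.initializer, it_b.initializer])

with tf.Session() as sess:
  sess.run(init_op)
  print(sess.run(it_a.get_next()), sess.run(it_b.get_next()))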