Move to using 'initializer' from 'initialize' to be more consistent with the tf.data APIs.
PiperOrigin-RevId: 289922514 Change-Id: I2f82cade789d707f287b9915d9856e2683aaa9f6
This commit is contained in:
parent
c9e0a34352
commit
019a2531c8
tensorflow/python
@ -357,7 +357,7 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
self.cached_session(config=config,
|
||||
target=master_target) as sess:
|
||||
iterator = distribution.make_input_fn_iterator(input_fn)
|
||||
sess.run(iterator.initialize())
|
||||
sess.run(iterator.initializer)
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
@ -375,7 +375,7 @@ class CollectiveAllReduceStrategyTestBase(
|
||||
|
||||
# After re-initializing the iterator, should be able to iterate again.
|
||||
if test_reinitialize:
|
||||
sess.run(iterator.initialize())
|
||||
sess.run(iterator.initializer)
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
|
@ -128,6 +128,7 @@ from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.deprecation import deprecated
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
@ -2285,13 +2286,24 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
||||
def get_next(self):
|
||||
return self._iterator.get_next()
|
||||
|
||||
@deprecated(None, "Use the iterator's `initializer` property instead.")
|
||||
def initialize(self):
|
||||
"""Initialize underlying iterators.
|
||||
|
||||
Returns:
|
||||
A list of any initializer ops that should be run.
|
||||
"""
|
||||
if eager_context.executing_eagerly():
|
||||
self._iterator = self._dataset.make_one_shot_iterator()
|
||||
return []
|
||||
else:
|
||||
return [self._iterator.initializer]
|
||||
|
||||
@property
|
||||
def initializer(self):
|
||||
"""Returns a list of ops that initialize the iterator."""
|
||||
return self.initialize()
|
||||
|
||||
# TODO(priyag): Delete this once all strategies use global batch size.
|
||||
@property
|
||||
def _global_batch_size(self):
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.deprecation import deprecated
|
||||
|
||||
|
||||
def get_distributed_dataset(dataset,
|
||||
@ -348,8 +349,7 @@ class DistributedIterator(object):
|
||||
class DistributedIteratorV1(DistributedIterator):
|
||||
"""Input Iterator for tf.data.DatasetV1."""
|
||||
|
||||
# TODO(anjalisridhar): Move to using `initializer` instead to be consistent
|
||||
# with tf.data iterator APIs.
|
||||
@deprecated(None, "Use the iterator's `initializer` property instead.")
|
||||
def initialize(self):
|
||||
"""Initialze underlying iterators.
|
||||
|
||||
@ -360,6 +360,7 @@ class DistributedIteratorV1(DistributedIterator):
|
||||
|
||||
@property
|
||||
def initializer(self):
|
||||
"""Returns a list of ops that initialize the iterator."""
|
||||
return self.initialize()
|
||||
|
||||
# TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
|
||||
|
@ -171,9 +171,7 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
if iteration_type == "get_next":
|
||||
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
||||
if isinstance(iterator, input_lib.DistributedIteratorV1):
|
||||
evaluate(control_flow_ops.group(iterator.initialize()))
|
||||
else:
|
||||
evaluate(control_flow_ops.group(iterator._initializer))
|
||||
evaluate(control_flow_ops.group(iterator.initializer))
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
@ -192,7 +190,7 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
|
||||
# After re-initializing the iterator, should be able to iterate again.
|
||||
if isinstance(iterator, input_lib.DistributedIteratorV1):
|
||||
evaluate(control_flow_ops.group(iterator.initialize()))
|
||||
evaluate(control_flow_ops.group(iterator.initializer))
|
||||
else:
|
||||
evaluate(control_flow_ops.group(iterator._initializer))
|
||||
|
||||
|
@ -101,7 +101,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
|
||||
metric, args=(iterator.get_next(),)))
|
||||
batches_per_update = distribution.num_replicas_in_sync
|
||||
|
||||
self.evaluate(iterator.initialize())
|
||||
self.evaluate(iterator.initializer)
|
||||
self.evaluate([v.initializer for v in metric.variables])
|
||||
|
||||
batches_consumed = 0
|
||||
|
@ -124,7 +124,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
|
||||
# replace "distribution.num_replicas_in_sync" with "1".
|
||||
batches_per_update = distribution.num_replicas_in_sync
|
||||
|
||||
self.evaluate(iterator.initialize())
|
||||
self.evaluate(iterator.initializer)
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
|
||||
batches_consumed = 0
|
||||
|
@ -65,7 +65,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _get_iterator(self, strategy, input_fn):
|
||||
iterator = strategy.make_input_fn_iterator(lambda _: input_fn())
|
||||
self.evaluate(iterator.initialize())
|
||||
self.evaluate(iterator.initializer)
|
||||
return iterator
|
||||
|
||||
@combinations.generate(
|
||||
|
@ -216,7 +216,7 @@ class MirroredVariableCreationTest(test.TestCase):
|
||||
|
||||
iterator = distribution.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
|
||||
self.evaluate(iterator.initialize())
|
||||
self.evaluate(iterator.initializer)
|
||||
features = iterator.get_next()
|
||||
|
||||
with distribution.scope():
|
||||
|
@ -536,7 +536,7 @@ class ParameterServerStrategyTestBase(
|
||||
self.cached_session(config=config,
|
||||
target=master_target) as sess:
|
||||
iterator = distribution.make_input_fn_iterator(input_fn)
|
||||
sess.run(iterator.initialize())
|
||||
sess.run(iterator.initializer)
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
@ -554,7 +554,7 @@ class ParameterServerStrategyTestBase(
|
||||
|
||||
# After re-initializing the iterator, should be able to iterate again.
|
||||
if test_reinitialize:
|
||||
sess.run(iterator.initialize())
|
||||
sess.run(iterator.initializer)
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
|
@ -55,7 +55,7 @@ class StandardInputStep(Step):
|
||||
self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
|
||||
|
||||
def initialize(self):
|
||||
return self._iterator.initialize()
|
||||
return self._iterator.initializer
|
||||
|
||||
|
||||
class StandardSingleLossStep(StandardInputStep):
|
||||
|
@ -344,7 +344,7 @@ class DistributionTestBase(test.TestCase):
|
||||
test_reinitialize=True,
|
||||
ignore_order=False):
|
||||
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
||||
evaluate(iterator.initialize())
|
||||
evaluate(iterator.initializer)
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
@ -362,7 +362,7 @@ class DistributionTestBase(test.TestCase):
|
||||
|
||||
# After re-initializing the iterator, should be able to iterate again.
|
||||
if test_reinitialize:
|
||||
evaluate(iterator.initialize())
|
||||
evaluate(iterator.initializer)
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
@ -414,7 +414,7 @@ class DistributionTestBase(test.TestCase):
|
||||
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
|
||||
i = strategy.make_dataset_iterator(ds)
|
||||
|
||||
self.evaluate(i.initialize())
|
||||
self.evaluate(i.initializer)
|
||||
|
||||
def run_and_concatenate(strategy, i):
|
||||
x, y = strategy.experimental_run(lambda z: z, i)
|
||||
|
@ -591,7 +591,7 @@ def get_iterator(dataset, distribution_strategy):
|
||||
|
||||
def initialize_iterator(iterator, distribution_strategy):
|
||||
with distribution_strategy.scope():
|
||||
init_op = control_flow_ops.group(iterator.initialize())
|
||||
init_op = control_flow_ops.group(iterator.initializer)
|
||||
if not context.executing_eagerly():
|
||||
K.get_session((init_op,)).run(init_op)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user