Move from 'initialize' to 'initializer' to be more consistent with the tf.data APIs.
PiperOrigin-RevId: 289922514
Change-Id: I2f82cade789d707f287b9915d9856e2683aaa9f6
commit 019a2531c8
parent c9e0a34352
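
For context, the rename follows the tf.data iterator convention: initialization is exposed as an 'initializer' property rather than an 'initialize()' method. Below is a minimal sketch of that convention with a plain tf.data iterator (TF 1.x graph mode assumed; not part of this commit):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

dataset = tf.data.Dataset.range(4)
iterator = tf.data.make_initializable_iterator(dataset)

with tf.Session() as sess:
  # tf.data convention: `initializer` is a property holding the init op.
  sess.run(iterator.initializer)
  print(sess.run(iterator.get_next()))  # -> 0
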
@@ -357,7 +357,7 @@ class CollectiveAllReduceStrategyTestBase(
          self.cached_session(config=config,
                              target=master_target) as sess:
       iterator = distribution.make_input_fn_iterator(input_fn)
-      sess.run(iterator.initialize())
+      sess.run(iterator.initializer)

       for expected_value in expected_values:
         next_element = iterator.get_next()
@@ -375,7 +375,7 @@ class CollectiveAllReduceStrategyTestBase(

       # After re-initializing the iterator, should be able to iterate again.
       if test_reinitialize:
-        sess.run(iterator.initialize())
+        sess.run(iterator.initializer)

         for expected_value in expected_values:
           next_element = iterator.get_next()
@@ -128,6 +128,7 @@ from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 from tensorflow.tools.docs import doc_controls

@@ -2285,13 +2286,24 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
     def get_next(self):
       return self._iterator.get_next()

+    @deprecated(None, "Use the iterator's `initializer` property instead.")
     def initialize(self):
+      """Initialize underlying iterators.
+
+      Returns:
+        A list of any initializer ops that should be run.
+      """
       if eager_context.executing_eagerly():
         self._iterator = self._dataset.make_one_shot_iterator()
         return []
       else:
         return [self._iterator.initializer]

+    @property
+    def initializer(self):
+      """Returns a list of ops that initialize the iterator."""
+      return self.initialize()
+
     # TODO(priyag): Delete this once all strategies use global batch size.
     @property
     def _global_batch_size(self):
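
The _DefaultDistributionExtended hunk above keeps initialize() for backward compatibility, marks it deprecated, and adds an 'initializer' property that delegates to it. The following standalone sketch illustrates that shim pattern; _ExampleIterator and init_op are illustrative names, not code from this commit:

from tensorflow.python.util.deprecation import deprecated


class _ExampleIterator(object):
  """Hypothetical iterator used only to illustrate the deprecation shim."""

  def __init__(self, init_op):
    self._init_op = init_op

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    # Old entry point: still works, but logs a deprecation warning.
    return [self._init_op]

  @property
  def initializer(self):
    # New entry point, matching tf.data iterators: a property, not a method.
    return self.initialize()
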
@@ -45,6 +45,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.util import nest
+from tensorflow.python.util.deprecation import deprecated


 def get_distributed_dataset(dataset,
@@ -348,8 +349,7 @@ class DistributedIterator(object):
 class DistributedIteratorV1(DistributedIterator):
   """Input Iterator for tf.data.DatasetV1."""

-  # TODO(anjalisridhar): Move to using `initializer` instead to be consistent
-  # with tf.data iterator APIs.
+  @deprecated(None, "Use the iterator's `initializer` property instead.")
   def initialize(self):
     """Initialze underlying iterators.

@@ -360,6 +360,7 @@ class DistributedIteratorV1(DistributedIterator):

   @property
   def initializer(self):
+    """Returns a list of ops that initialize the iterator."""
     return self.initialize()

   # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
@@ -171,9 +171,7 @@ class DistributedIteratorTestBase(test.TestCase):
     if iteration_type == "get_next":
       evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
       if isinstance(iterator, input_lib.DistributedIteratorV1):
-        evaluate(control_flow_ops.group(iterator.initialize()))
-      else:
-        evaluate(control_flow_ops.group(iterator._initializer))
+        evaluate(control_flow_ops.group(iterator.initializer))

       for expected_value in expected_values:
         next_element = iterator.get_next()
@@ -192,7 +190,7 @@ class DistributedIteratorTestBase(test.TestCase):

       # After re-initializing the iterator, should be able to iterate again.
       if isinstance(iterator, input_lib.DistributedIteratorV1):
-        evaluate(control_flow_ops.group(iterator.initialize()))
+        evaluate(control_flow_ops.group(iterator.initializer))
       else:
         evaluate(control_flow_ops.group(iterator._initializer))

@@ -101,7 +101,7 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
             metric, args=(iterator.get_next(),)))
     batches_per_update = distribution.num_replicas_in_sync

-    self.evaluate(iterator.initialize())
+    self.evaluate(iterator.initializer)
     self.evaluate([v.initializer for v in metric.variables])

     batches_consumed = 0
@@ -124,7 +124,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
       # replace "distribution.num_replicas_in_sync" with "1".
       batches_per_update = distribution.num_replicas_in_sync

-      self.evaluate(iterator.initialize())
+      self.evaluate(iterator.initializer)
       self.evaluate(variables.local_variables_initializer())

       batches_consumed = 0
@@ -65,7 +65,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):

   def _get_iterator(self, strategy, input_fn):
     iterator = strategy.make_input_fn_iterator(lambda _: input_fn())
-    self.evaluate(iterator.initialize())
+    self.evaluate(iterator.initializer)
     return iterator

   @combinations.generate(
@@ -216,7 +216,7 @@ class MirroredVariableCreationTest(test.TestCase):

     iterator = distribution.make_input_fn_iterator(
         lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
-    self.evaluate(iterator.initialize())
+    self.evaluate(iterator.initializer)
     features = iterator.get_next()

     with distribution.scope():
@@ -536,7 +536,7 @@ class ParameterServerStrategyTestBase(
          self.cached_session(config=config,
                              target=master_target) as sess:
       iterator = distribution.make_input_fn_iterator(input_fn)
-      sess.run(iterator.initialize())
+      sess.run(iterator.initializer)

       for expected_value in expected_values:
         next_element = iterator.get_next()
@@ -554,7 +554,7 @@ class ParameterServerStrategyTestBase(

       # After re-initializing the iterator, should be able to iterate again.
       if test_reinitialize:
-        sess.run(iterator.initialize())
+        sess.run(iterator.initializer)

         for expected_value in expected_values:
           next_element = iterator.get_next()
@@ -55,7 +55,7 @@ class StandardInputStep(Step):
     self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())

   def initialize(self):
-    return self._iterator.initialize()
+    return self._iterator.initializer


 class StandardSingleLossStep(StandardInputStep):
@@ -344,7 +344,7 @@ class DistributionTestBase(test.TestCase):
                               test_reinitialize=True,
                               ignore_order=False):
     evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
-    evaluate(iterator.initialize())
+    evaluate(iterator.initializer)

     for expected_value in expected_values:
       next_element = iterator.get_next()
@@ -362,7 +362,7 @@ class DistributionTestBase(test.TestCase):

     # After re-initializing the iterator, should be able to iterate again.
     if test_reinitialize:
-      evaluate(iterator.initialize())
+      evaluate(iterator.initializer)

       for expected_value in expected_values:
         next_element = iterator.get_next()
@@ -414,7 +414,7 @@ class DistributionTestBase(test.TestCase):
     ds = ds.batch(batch_size, drop_remainder=drop_remainder)
     i = strategy.make_dataset_iterator(ds)

-    self.evaluate(i.initialize())
+    self.evaluate(i.initializer)

     def run_and_concatenate(strategy, i):
       x, y = strategy.experimental_run(lambda z: z, i)
@@ -591,7 +591,7 @@ def get_iterator(dataset, distribution_strategy):

 def initialize_iterator(iterator, distribution_strategy):
   with distribution_strategy.scope():
-    init_op = control_flow_ops.group(iterator.initialize())
+    init_op = control_flow_ops.group(iterator.initializer)
     if not context.executing_eagerly():
       K.get_session((init_op,)).run(init_op)

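
In the initialize_iterator hunk above, the iterator's 'initializer' may be a list of ops (one per replica or device), so it is wrapped in control_flow_ops.group before being run in graph mode. A minimal sketch of that grouping step with a single plain tf.data iterator (TF 1.x graph mode assumed; the single-iterator setup is illustrative):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

dataset = tf.data.Dataset.range(8).batch(2)
iterator = tf.data.make_initializable_iterator(dataset)

# A distributed iterator can expose several initializer ops; tf.group
# collapses a list of them into a single op that is run once.
init_op = tf.group([iterator.initializer])

with tf.Session() as sess:
  sess.run(init_op)
  print(sess.run(iterator.get_next()))  # -> [0 1]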