Switch from deprecated strategy.distribute_dataset() to
strategy.make_input_fn_iterator() in contrib/distribute tests.

PiperOrigin-RevId: 229752452
Author:     A. Unique TensorFlower
Date:       2019-01-17 08:34:02 -08:00
Committer:  TensorFlower Gardener
parent 202525cd4b
commit f01d357ca3
7 changed files with 26 additions and 39 deletions
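
The change is mechanical across the seven files: the deprecated path built a
distributed dataset and then picked an iterator type depending on execution
mode, while the replacement routes everything through make_input_fn_iterator().
Its input function receives an InputContext (unused in these tests, hence the
`lambda _:` wrappers), and the returned iterator exposes a single initialize()
that works in both graph and eager mode. A condensed before/after sketch of the
pattern, with dataset_fn standing in for each test's dataset function:

    # Before (deprecated): iterator construction depends on execution mode.
    ds = distribution.distribute_dataset(dataset_fn)
    if context.executing_eagerly():
      iterator = ds.make_one_shot_iterator()
    else:
      iterator = ds.make_initializable_iterator()
      self.evaluate(iterator.initializer)

    # After: one code path; the input_fn ignores its InputContext argument,
    # and initialize() is valid to call (or run) in either mode.
    iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
    self.evaluate(iterator.initialize())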

@@ -95,8 +95,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
   def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
     with ops.Graph().as_default(), distribution.scope():
-      iterator = distribution.distribute_dataset(
-          dataset_fn).make_initializable_iterator()
+      iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
       if isinstance(distribution, tpu_strategy.TPUStrategy):

         def step_fn(ctx, inputs):
           value, update = distribution.extended.call_for_each_replica(
@@ -121,7 +120,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):

       # replace "distribution.num_replicas_in_sync" with "1".
       batches_per_update = distribution.num_replicas_in_sync
-      self.evaluate(iterator.initializer)
+      self.evaluate(iterator.initialize())
       self.evaluate(variables.local_variables_initializer())

       batches_consumed = 0

@@ -41,12 +41,9 @@ from tensorflow.python.ops.losses import losses_impl

 class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):

-  def _get_iterator(self, ds):
-    if context.executing_eagerly():
-      iterator = ds.make_one_shot_iterator()
-    else:
-      iterator = ds.make_initializable_iterator()
-      self.evaluate(iterator.initializer)
+  def _get_iterator(self, strategy, input_fn):
+    iterator = strategy.make_input_fn_iterator(lambda _: input_fn())
+    self.evaluate(iterator.initialize())
     return iterator

   @combinations.generate(
@@ -70,7 +67,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
             distribution.extended.call_for_each_replica(
                 model_fn, args=(inputs,)))

-      iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+      iterator = self._get_iterator(distribution, dataset_fn)

       def run_step():
         return distribution.extended.experimental_run_steps_on_iterator(
@@ -102,7 +99,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       model_fn, dataset_fn, layer = minimize_loss_example(
           optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)

-      iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+      iterator = self._get_iterator(distribution, dataset_fn)

       def run_step():
         return distribution.group(
@@ -161,7 +158,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
             distribution.extended.call_for_each_replica(
                 model_fn, args=(inputs,)))

-      iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+      iterator = self._get_iterator(distribution, dataset_fn)

       def run_step():
         return distribution.extended.experimental_run_steps_on_iterator(
@@ -230,7 +227,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
         return control_flow_ops.group(fetches)

-      iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+      iterator = self._get_iterator(distribution, dataset_fn)

       def run_step():
         return distribution.extended.experimental_run_steps_on_iterator(
@@ -322,7 +319,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
             distribution.extended.call_for_each_replica(
                 model_fn, args=(inputs,)))

-      iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+      iterator = self._get_iterator(distribution, dataset_fn)

       def run_step():
         return distribution.extended.experimental_run_steps_on_iterator(
@@ -413,7 +410,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
             output=loss)
         return distribution.group(train_op)

-      iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+      iterator = self._get_iterator(distribution, dataset_fn)

       def run_step():
         initial_loss = lambda: constant_op.constant(1e7)

@@ -366,14 +366,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
                        (layer2.kernel, layer2.bias),
                        (layer3.kernel, layer3.bias)]

-    ds = distribution.distribute_dataset(
-        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
-    if context.executing_eagerly():
-      iterator = ds.make_one_shot_iterator()
-    else:
-      iterator = ds.make_initializable_iterator()
-      self.evaluate([iterator.initializer])
+    iterator = distribution.make_input_fn_iterator(
+        lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
+    self.evaluate(iterator.initialize())

     features = iterator.get_next()

     with distribution.scope():

@@ -51,7 +51,7 @@ class Monitor(object):
     else:
       if session is None:
         raise ValueError("Should provide a `session` in Graph mode.")
-      session.run(step_callable._iterator.initializer)  # pylint: disable=protected-access
+      session.run(step_callable.initialize())
       self._run_step = session.make_callable(step_callable())
       session.run(variables.global_variables_initializer())

@@ -41,12 +41,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
     with distribution.scope():
       model_fn, dataset_fn, layer = minimize_loss_example(
           optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)

-      ds = distribution.distribute_dataset(dataset_fn)
-      if context.executing_eagerly():
-        iterator = ds.make_one_shot_iterator()
-      else:
-        iterator = ds.make_initializable_iterator()
+      iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())

       def run_step():
         return control_flow_ops.group(
@@ -56,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):

       if not context.executing_eagerly():
         with self.cached_session() as sess:
-          sess.run(iterator.initializer)
+          sess.run(iterator.initialize())
           run_step = sess.make_callable(run_step())

       self.evaluate(variables.global_variables_initializer())

@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
 from tensorflow.python.training import optimizer as optimizer_lib
@@ -33,6 +32,9 @@ class Step(object):
   def distribution(self):
     return self._distribution

+  def initialize(self):
+    return []
+
   def __call__(self):
     """Perform one step of this training algorithm."""
     raise NotImplementedError("must be implemented in descendants")
@@ -50,12 +52,10 @@ class StandardInputStep(Step):

   def __init__(self, dataset_fn, distribution):
     super(StandardInputStep, self).__init__(distribution)
-    self._distributed_input = distribution.distribute_dataset(dataset_fn)
-    if context.executing_eagerly():
-      self._iterator = self._distributed_input.make_one_shot_iterator()
-    else:
-      # TODO(priyag): Expose initializer via some initializer property.
-      self._iterator = self._distributed_input.make_initializable_iterator()
+    self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
+
+  def initialize(self):
+    return self._iterator.initialize()


 class StandardSingleLossStep(StandardInputStep):

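The step_fn.py hunks above supply the API the other files rely on: Step gains
an initialize() hook that returns an empty fetch list by default, and
StandardInputStep overrides it to return its iterator's initializer. That lets
callers such as Monitor and the test below initialize any Step uniformly,
instead of reaching into the protected _iterator attribute. A rough sketch of
the calling convention, assuming `step` is any Step instance and `sess` an
open session:

    if context.executing_eagerly():
      step.initialize()  # initialization runs immediately
      run_step = step
    else:
      # In graph mode, initialize() returns fetches for Session.run:
      # the iterator initializer op, or [] for steps with no input.
      sess.run(step.initialize())
      run_step = sess.make_callable(step())
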
@@ -46,10 +46,11 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
         optimizer_fn, distribution, use_bias=True, iterations_per_step=2)

     if context.executing_eagerly():
+      single_loss_step.initialize()
       run_step = single_loss_step
     else:
       with self.cached_session() as sess:
-        sess.run(single_loss_step._iterator.initializer)
+        sess.run(single_loss_step.initialize())
         run_step = sess.make_callable(single_loss_step())

     self.evaluate(variables.global_variables_initializer())