Switch from deprecated strategy.distribute_dataset() to
strategy.make_input_fn_iterator() in contrib/distribute tests. PiperOrigin-RevId: 229752452
This commit is contained in:
parent
202525cd4b
commit
f01d357ca3
tensorflow/contrib/distribute/python
@ -95,8 +95,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
|
||||
with ops.Graph().as_default(), distribution.scope():
|
||||
iterator = distribution.distribute_dataset(
|
||||
dataset_fn).make_initializable_iterator()
|
||||
iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
|
||||
if isinstance(distribution, tpu_strategy.TPUStrategy):
|
||||
def step_fn(ctx, inputs):
|
||||
value, update = distribution.extended.call_for_each_replica(
|
||||
@ -121,7 +120,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
|
||||
# replace "distribution.num_replicas_in_sync" with "1".
|
||||
batches_per_update = distribution.num_replicas_in_sync
|
||||
|
||||
self.evaluate(iterator.initializer)
|
||||
self.evaluate(iterator.initialize())
|
||||
self.evaluate(variables.local_variables_initializer())
|
||||
|
||||
batches_consumed = 0
|
||||
|
@ -41,12 +41,9 @@ from tensorflow.python.ops.losses import losses_impl
|
||||
|
||||
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _get_iterator(self, ds):
|
||||
if context.executing_eagerly():
|
||||
iterator = ds.make_one_shot_iterator()
|
||||
else:
|
||||
iterator = ds.make_initializable_iterator()
|
||||
self.evaluate(iterator.initializer)
|
||||
def _get_iterator(self, strategy, input_fn):
|
||||
iterator = strategy.make_input_fn_iterator(lambda _: input_fn())
|
||||
self.evaluate(iterator.initialize())
|
||||
return iterator
|
||||
|
||||
@combinations.generate(
|
||||
@ -70,7 +67,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
distribution.extended.call_for_each_replica(
|
||||
model_fn, args=(inputs,)))
|
||||
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
iterator = self._get_iterator(distribution, dataset_fn)
|
||||
|
||||
def run_step():
|
||||
return distribution.extended.experimental_run_steps_on_iterator(
|
||||
@ -102,7 +99,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
model_fn, dataset_fn, layer = minimize_loss_example(
|
||||
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
|
||||
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
iterator = self._get_iterator(distribution, dataset_fn)
|
||||
|
||||
def run_step():
|
||||
return distribution.group(
|
||||
@ -161,7 +158,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
distribution.extended.call_for_each_replica(
|
||||
model_fn, args=(inputs,)))
|
||||
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
iterator = self._get_iterator(distribution, dataset_fn)
|
||||
|
||||
def run_step():
|
||||
return distribution.extended.experimental_run_steps_on_iterator(
|
||||
@ -230,7 +227,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
|
||||
return control_flow_ops.group(fetches)
|
||||
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
iterator = self._get_iterator(distribution, dataset_fn)
|
||||
|
||||
def run_step():
|
||||
return distribution.extended.experimental_run_steps_on_iterator(
|
||||
@ -322,7 +319,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
distribution.extended.call_for_each_replica(
|
||||
model_fn, args=(inputs,)))
|
||||
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
iterator = self._get_iterator(distribution, dataset_fn)
|
||||
|
||||
def run_step():
|
||||
return distribution.extended.experimental_run_steps_on_iterator(
|
||||
@ -413,7 +410,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
output=loss)
|
||||
return distribution.group(train_op)
|
||||
|
||||
iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
|
||||
iterator = self._get_iterator(distribution, dataset_fn)
|
||||
|
||||
def run_step():
|
||||
initial_loss = lambda: constant_op.constant(1e7)
|
||||
|
@ -366,14 +366,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
|
||||
(layer2.kernel, layer2.bias),
|
||||
(layer3.kernel, layer3.bias)]
|
||||
|
||||
ds = distribution.distribute_dataset(
|
||||
lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
|
||||
if context.executing_eagerly():
|
||||
iterator = ds.make_one_shot_iterator()
|
||||
else:
|
||||
iterator = ds.make_initializable_iterator()
|
||||
self.evaluate([iterator.initializer])
|
||||
|
||||
iterator = distribution.make_input_fn_iterator(
|
||||
lambda _: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
|
||||
self.evaluate(iterator.initialize())
|
||||
features = iterator.get_next()
|
||||
|
||||
with distribution.scope():
|
||||
|
@ -51,7 +51,7 @@ class Monitor(object):
|
||||
else:
|
||||
if session is None:
|
||||
raise ValueError("Should provide a `session` in Graph mode.")
|
||||
session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
|
||||
session.run(step_callable.initialize())
|
||||
self._run_step = session.make_callable(step_callable())
|
||||
session.run(variables.global_variables_initializer())
|
||||
|
||||
|
@ -41,12 +41,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
|
||||
with distribution.scope():
|
||||
model_fn, dataset_fn, layer = minimize_loss_example(
|
||||
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
|
||||
|
||||
ds = distribution.distribute_dataset(dataset_fn)
|
||||
if context.executing_eagerly():
|
||||
iterator = ds.make_one_shot_iterator()
|
||||
else:
|
||||
iterator = ds.make_initializable_iterator()
|
||||
iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
|
||||
|
||||
def run_step():
|
||||
return control_flow_ops.group(
|
||||
@ -56,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
|
||||
|
||||
if not context.executing_eagerly():
|
||||
with self.cached_session() as sess:
|
||||
sess.run(iterator.initializer)
|
||||
sess.run(iterator.initialize())
|
||||
run_step = sess.make_callable(run_step())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.training import optimizer as optimizer_lib
|
||||
|
||||
|
||||
@ -33,6 +32,9 @@ class Step(object):
|
||||
def distribution(self):
|
||||
return self._distribution
|
||||
|
||||
def initialize(self):
|
||||
return []
|
||||
|
||||
def __call__(self):
|
||||
"""Perform one step of this training algorithm."""
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
@ -50,12 +52,10 @@ class StandardInputStep(Step):
|
||||
|
||||
def __init__(self, dataset_fn, distribution):
|
||||
super(StandardInputStep, self).__init__(distribution)
|
||||
self._distributed_input = distribution.distribute_dataset(dataset_fn)
|
||||
if context.executing_eagerly():
|
||||
self._iterator = self._distributed_input.make_one_shot_iterator()
|
||||
else:
|
||||
# TODO(priyag): Expose initializer via some initializer property.
|
||||
self._iterator = self._distributed_input.make_initializable_iterator()
|
||||
self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
|
||||
|
||||
def initialize(self):
|
||||
return self._iterator.initialize()
|
||||
|
||||
|
||||
class StandardSingleLossStep(StandardInputStep):
|
||||
|
@ -46,10 +46,11 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
optimizer_fn, distribution, use_bias=True, iterations_per_step=2)
|
||||
|
||||
if context.executing_eagerly():
|
||||
single_loss_step.initialize()
|
||||
run_step = single_loss_step
|
||||
else:
|
||||
with self.cached_session() as sess:
|
||||
sess.run(single_loss_step._iterator.initializer)
|
||||
sess.run(single_loss_step.initialize())
|
||||
run_step = sess.make_callable(single_loss_step())
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user