Add tf.distribute.Strategy.experimental_make_numpy_iterator() function.

PiperOrigin-RevId: 228584021
This commit is contained in:
A. Unique TensorFlower 2019-01-09 14:04:07 -08:00 committed by TensorFlower Gardener
parent 2b2c042367
commit 0a88318ac7
28 changed files with 489 additions and 197 deletions

View File

@@ -131,6 +131,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:numpy_dataset",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
@@ -153,6 +154,7 @@ py_library(
"//tensorflow/python/distribute:cross_device_utils",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/distribute:numpy_dataset",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
],
@@ -178,6 +180,7 @@ py_library(
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"//third_party/py/numpy",
],
)
@@ -303,6 +306,7 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:numpy_dataset",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
],

View File

@@ -28,6 +28,7 @@ from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
@@ -86,6 +87,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
else:
local_devices = ("/device:CPU:0",)
self._worker_device = device_util.canonicalize("/device:CPU:0")
self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
self._collective_keys = cross_device_utils.CollectiveKeys()
self._initialize_local(local_devices)
@@ -121,6 +123,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
task_id)
self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
if num_gpus_per_worker:
local_devices = tuple(
"%s/device:GPU:%d" % (self._worker_device, i)
@@ -157,6 +160,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
if colocate_with is None:
device_map = self._device_map
logical_device = 0 # TODO(josh11b): Get logical device from scope here.
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
else:
device_map = colocate_with.device_map
logical_device = colocate_with.logical_device
@@ -347,4 +353,4 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
# TODO(priyag): Delete this once all strategies use global batch size.
@property
def _global_batch_size(self):
return False
return True

View File

@@ -466,6 +466,13 @@ class LocalCollectiveAllReduceStrategy(
with self.cached_session(config=config, target=target):
self._test_all_reduce_mean_gradient_tape(distribution)
def testNumpyIterator(self):
num_gpus = 2
if context.num_gpus() < num_gpus:
self.skipTest('Not enough GPUs')
strategy, _, _ = self._get_test_object(None, None, num_gpus)
self._test_numpy_iterator(strategy)
if __name__ == '__main__':
test.main()

View File

@@ -429,15 +429,6 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,
class TestDistributionStrategyWithNumpyArrays(test.TestCase,
parameterized.TestCase):
@combinations.generate(strategy_for_numpy_input_combinations())
def test_creating_var_with_numpy_arrays(self, distribution):
with self.cached_session():
x = np.asarray(np.random.random((64, 3)), dtype=np.float32)
var_x = distributed_training_utils.get_var_for_numpy(distribution, x)
val = self.evaluate(var_x.value())
# Verify that the numpy value is copied to the variable.
self.assertAllEqual(x, val)
@combinations.generate(strategy_for_numpy_input_combinations())
def test_calculating_input_params_no_steps_no_batch_size(self, distribution):
# Calculate the per_replica_batch_size scaling factor for strategies
@@ -576,26 +567,26 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
inputs = np.zeros((64, 3), dtype=np.float32)
targets = np.zeros((64, 4), dtype=np.float32)
inputs = np.zeros((64, 3), dtype=np.float32)
targets = np.zeros((64, 4), dtype=np.float32)
# Call fit with validation data
model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
validation_data=(inputs, targets))
# Call fit with validation data
model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
validation_data=(inputs, targets))
# TODO(anjalisridhar): We need tests for when the batch size and steps are
# smaller and results in a 0 batch_size and steps value.
model.evaluate(inputs, targets)
# with steps
model.evaluate(inputs, targets, steps=2)
# with batch_size
model.evaluate(inputs, targets, batch_size=8)
# TODO(anjalisridhar): We need tests for when the batch size and steps
# are smaller and results in a 0 batch_size and steps value.
model.evaluate(inputs, targets)
# with steps
model.evaluate(inputs, targets, steps=2)
# with batch_size
model.evaluate(inputs, targets, batch_size=8)
model.predict(inputs)
# with steps
model.predict(inputs, steps=2)
# with batch_size
model.predict(inputs, batch_size=8)
model.predict(inputs)
# with steps
model.predict(inputs, steps=2)
# with batch_size
model.predict(inputs, batch_size=8)
@combinations.generate(strategy_for_numpy_input_combinations())
def test_calling_model_with_nested_numpy_arrays(self, distribution):
@@ -1192,21 +1183,6 @@ class TestDistributionStrategyValidation(test.TestCase,
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics)
@combinations.generate(all_strategy_combinations_minus_default())
def test_loop_in_scope(self, distribution):
with self.cached_session():
with self.assertRaisesRegexp(
RuntimeError, 'should not be run inside the tf.distribute.Strategy'):
with distribution.scope():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
model.compile(optimizer, loss)
input_array = np.zeros((3, 3), dtype=np.float32)
model.predict(input_array)
if __name__ == '__main__':
test.main()

View File

@@ -104,6 +104,36 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
auto_shard_dataset)
super(MirroredStrategy, self).__init__(extended)
# Override to change the documentation to reflect the different handling of
# global vs. local batch size between core and contrib.
def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation
self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None):
"""Makes an iterator for input provided via a nest of numpy arrays.
NOTE: The `batch_size` argument here behaves differently from the core
version of `MirroredStrategy`: it is the per-replica batch size rather than
the global batch size.
Args:
numpy_input: A nest of NumPy input arrays that will be distributed evenly
across all replicas.
batch_size: The number of entries from the array each replica should
consume in one step of the computation. This is the per-replica batch
size; the global batch size will be this times `num_replicas_in_sync`.
num_epochs: The number of times to iterate through the examples. A value
of `None` means repeat forever.
shuffle: Size of buffer to use for shuffling the input examples.
Use `None` to disable shuffling.
session: (TensorFlow v1.x graph execution only) A session used for
initialization.
Returns:
A `tf.distribute.InputIterator` which returns inputs for each step of the
computation. The user should call `initialize` on the returned iterator.
"""
return super(MirroredStrategy, self).experimental_make_numpy_iterator(
numpy_input, batch_size, num_epochs, shuffle, session)
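A minimal usage sketch of this contrib override, assuming two local GPUs and purely illustrative array shapes (the `session` argument only matters in v1 graph execution):

import numpy as np
import tensorflow as tf

# Hypothetical data; shapes chosen only for illustration.
features = np.zeros((64, 3), dtype=np.float32)
labels = np.zeros((64, 4), dtype=np.float32)

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
with strategy.scope(), tf.Session() as sess:
  # batch_size=8 is per replica here: each step consumes 16 examples in total.
  iterator = strategy.experimental_make_numpy_iterator(
      (features, labels), batch_size=8, num_epochs=1, shuffle=None,
      session=sess)
  sess.run(iterator.initialize())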
class MirroredExtended(CoreMirroredExtended):
"""Implementation of (contrib) MirroredStrategy."""

View File

@@ -116,6 +116,9 @@ class MirroredTwoDeviceDistributionTest(
self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
expected_values)
def testNumpyIterator(self, distribution):
self._test_numpy_iterator(distribution)
def testGlobalStepUpdate(self, distribution):
self._test_global_step_update(distribution)

View File

@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -50,8 +51,8 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
super(OneDeviceExtended, self).__init__(container_strategy)
self._device = device
self._default_device = device
worker = device_util.canonicalize("/device:CPU:0")
worker_device_pairs = [(worker, [self._device])]
self._input_device = device_util.canonicalize("/device:CPU:0")
worker_device_pairs = [(self._input_device, [self._device])]
device_map = values.SingleDeviceMap(device)
self._input_workers = input_lib.InputWorkers(
device_map, worker_device_pairs)
@@ -82,6 +83,10 @@ class OneDeviceExtended(distribute_lib.DistributionStrategyExtended):
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, [distribute_lib.InputContext()])
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._input_device, session)
def _broadcast_to(self, tensor, destinations):
del destinations
return tensor

View File

@@ -59,6 +59,10 @@ class OneDeviceStrategyTest(
self._test_input_fn_iterator(
iterator, d.extended.worker_devices, expected_values)
@test_util.run_in_graph_and_eager_modes
def testNumpyIterator(self):
self._test_numpy_iterator(self._get_distribution_strategy())
def testAllReduceSum(self):
self._test_all_reduce_sum(self._get_distribution_strategy())

View File

@@ -89,6 +89,37 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
super(ParameterServerStrategy, self).__init__(
ParameterServerExtended(self, num_gpus_per_worker))
# Override to change the documentation to reflect the different handling of
# global vs. local batch size between core and contrib.
def experimental_make_numpy_iterator( # pylint: disable=useless-super-delegation
self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None):
"""Makes an iterator for input provided via a nest of numpy arrays.
NOTE: The `batch_size` argument here behaves differently from the core
version of `ParameterServerStrategy`: it is the per-replica batch size
rather than the global batch size.
Args:
numpy_input: A nest of NumPy input arrays that will be distributed evenly
across all replicas.
batch_size: The number of entries from the array each replica should
consume in one step of the computation. This is the per-replica batch
size; the global batch size will be this times `num_replicas_in_sync`.
num_epochs: The number of times to iterate through the examples. A value
of `None` means repeat forever.
shuffle: Size of buffer to use for shuffling the input examples.
Use `None` to disable shuffling.
session: (TensorFlow v1.x graph execution only) A session used for
initialization.
Returns:
A `tf.distribute.InputIterator` which returns inputs for each step of the
computation. The user should call `initialize` on the returned iterator.
"""
return super(ParameterServerStrategy,
self).experimental_make_numpy_iterator(
numpy_input, batch_size, num_epochs, shuffle, session)
class ParameterServerExtended(CoreParameterServerExtended):
"""Implementation of ParameterServerStrategy."""

View File

@@ -893,5 +893,17 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
strategy.extended.call_for_each_replica(f)
class LocalParameterServerStrategyTest(strategy_test_lib.DistributionTestBase,
parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['graph', 'eager'],
use_core_strategy=[True, False],
required_gpus=2))
def testNumpyIterator(self, use_core_strategy):
strategy, _, _ = create_test_objects(
num_gpus=2, use_core_strategy=use_core_strategy)
self._test_numpy_iterator(strategy)
if __name__ == '__main__':
test.main()

View File

@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context as ds_context
@@ -295,6 +297,33 @@ class DistributionTestBase(test.TestCase):
global_step_values = self.evaluate(global_step_tensors)
self.assertEqual((1,) * len(global_step_tensors), global_step_values)
def _test_numpy_iterator(self, strategy):
with strategy.scope(), self.cached_session() as sess:
x = np.asarray([[1, 2], [6, 12], [2, 4],
[5, 10], [3, 6], [4, 8]])
y = np.asarray([5, 4, 3, 2, 1, 0])
batch_size = 6
if not strategy.extended._global_batch_size: # pylint: disable=protected-access
batch_size = batch_size // strategy.num_replicas_in_sync
i = strategy.experimental_make_numpy_iterator(
(x, y), batch_size=batch_size, num_epochs=2, shuffle=None,
session=sess)
self.evaluate(i.initialize())
def run_and_concatenate(strategy, i):
x, y = strategy.experimental_run(lambda z: z, i)
x, y = self.evaluate((strategy.unwrap(x), strategy.unwrap(y)))
return np.concatenate(x), np.concatenate(y)
x_1, y_1 = run_and_concatenate(strategy, i)
self.assertAllEqual(x, x_1)
self.assertAllEqual(y, y_1)
x_2, y_2 = run_and_concatenate(strategy, i)
self.assertAllEqual(x, x_2)
self.assertAllEqual(y, y_2)
with self.assertRaises(errors.OutOfRangeError):
run_and_concatenate(strategy, i)
class OneDeviceDistributionTestBase(test.TestCase):
"""Some tests that should work with any one-device DistributionStrategy."""

View File

@@ -34,6 +34,7 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver_lib
@@ -303,6 +304,11 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
functools.partial(self._call_dataset_fn, dataset_fn),
self._input_workers)
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, numpy_dataset.SingleDevice(self.get_host_cpu_device(0)),
session)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
# a mechanism to infer the outputs of `fn`. Pending b/110550782.
@@ -466,6 +472,9 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
if colocate_with is None:
device_map = self._device_map
logical_device = 0 # TODO(josh11b): Get logical device from scope here.
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
else:
device_map = colocate_with.device_map
logical_device = colocate_with.logical_device

View File

@@ -124,6 +124,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":device_util",
":numpy_dataset",
":reduce_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
@@ -221,6 +222,7 @@ py_library(
":distribute_lib",
":input_lib",
":multi_worker_util",
":numpy_dataset",
":reduce_util",
":shared_variable_creator",
":values",
@@ -249,6 +251,7 @@ py_library(
deps = [
":input_lib",
":mirrored_strategy",
":numpy_dataset",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
@@ -276,6 +279,35 @@ py_library(
],
)
py_library(
name = "numpy_dataset",
srcs = ["numpy_dataset.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
],
)
py_test(
name = "numpy_dataset_test",
size = "small",
srcs = ["numpy_dataset_test.py"],
srcs_version = "PY2AND3",
deps = [
":numpy_dataset",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:test",
"//third_party/py/numpy",
],
)
py_library(
name = "input_lib",
srcs = ["input_lib.py"],

View File

@@ -26,6 +26,7 @@ import enum
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context as eager_context
from tensorflow.python.framework import constant_op
@@ -360,7 +361,7 @@ class DistributionStrategy(object):
return self._extended._distribute_dataset(dataset_fn) # pylint: disable=protected-access
def make_dataset_iterator(self, dataset):
"""Makes an iterator for input provided via input_dataset.
"""Makes an iterator for input provided via `dataset`.
Data from the given dataset will be distributed evenly across all the
compute replicas. We will assume that the input dataset is batched by the
@@ -418,6 +419,40 @@ class DistributionStrategy(object):
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
input_fn, replication_mode=replication_mode)
def experimental_make_numpy_iterator(
self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None):
"""Makes an iterator for input provided via a nest of numpy arrays.
Args:
numpy_input: A nest of NumPy input arrays that will be distributed evenly
across all replicas. Note that lists of NumPy arrays are stacked,
as that is normal `tf.data.Dataset` behavior.
batch_size: The number of entries from the array we should consume in one
step of the computation, across all replicas. This is the global batch
size. It should be divisible by `num_replicas_in_sync`.
num_epochs: The number of times to iterate through the examples. A value
of `None` means repeat forever.
shuffle: Size of buffer to use for shuffling the input examples.
Use `None` to disable shuffling.
session: (TensorFlow v1.x graph execution only) A session used for
initialization.
Returns:
A `tf.distribute.InputIterator` which returns inputs for each step of the
computation. The user should call `initialize` on the returned iterator.
"""
ds = self.extended.experimental_make_numpy_dataset(
numpy_input, session=session)
if shuffle:
ds = ds.shuffle(shuffle)
if num_epochs != 1:
ds = ds.repeat(num_epochs)
# We need to use the drop_remainder argument to get a known static
# input shape which is required for TPUs.
drop_remainder = self.extended.experimental_require_static_shapes
ds = ds.batch(batch_size, drop_remainder=drop_remainder)
return self.make_dataset_iterator(ds)
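A rough end-to-end sketch of this core API (illustrative data; `batch_size` here is the global batch size and should be divisible by `num_replicas_in_sync`):

import numpy as np
import tensorflow as tf

features = np.asarray([[1., 2.], [3., 4.], [5., 6.], [7., 8.]], dtype=np.float32)
labels = np.asarray([0., 1., 0., 1.], dtype=np.float32)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope(), tf.Session() as sess:
  # batch_size=4 is the global batch size, split evenly across replicas.
  iterator = strategy.experimental_make_numpy_iterator(
      (features, labels), batch_size=4, num_epochs=2, shuffle=None,
      session=sess)
  sess.run(iterator.initialize())
  x, y = strategy.experimental_run(lambda inputs: inputs, iterator)
  print(sess.run(strategy.unwrap(x)))  # One batch slice per replica.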
def experimental_run(self, fn, input_iterator=None):
"""Runs ops in `fn` on each replica, with inputs from `input_iterator`.
@@ -1083,6 +1118,29 @@ class DistributionStrategyExtended(object):
def _make_input_fn_iterator(self, input_fn, replication_mode):
raise NotImplementedError("must be implemented in descendants")
def experimental_make_numpy_dataset(self, numpy_input, session=None):
"""Makes a dataset for input provided via a numpy array.
This avoids adding `numpy_input` as a large constant in the graph,
and copies the data to the machine or machines that will be processing
the input.
Args:
numpy_input: A nest of NumPy input arrays that will be distributed evenly
across all replicas. Note that lists of NumPy arrays are stacked,
as that is normal `tf.data.Dataset` behavior.
session: (TensorFlow v1.x graph execution only) A session used for
initialization.
Returns:
A `tf.data.Dataset` representing `numpy_input`.
"""
_require_cross_replica_context_extended(self)
return self._experimental_make_numpy_dataset(numpy_input, session=session)
def _experimental_make_numpy_dataset(self, numpy_input, session):
raise NotImplementedError("must be implemented in descendants")
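When custom shuffling or batching is needed, the dataset-level entry point can be combined with `make_dataset_iterator` directly; a sketch under the same kind of illustrative assumptions (v1 graph mode, made-up shapes):

import numpy as np
import tensorflow as tf

features = np.random.random((100, 3)).astype(np.float32)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope(), tf.Session() as sess:
  ds = strategy.extended.experimental_make_numpy_dataset(features, session=sess)
  # drop_remainder keeps the per-step shape static, which TPUs require.
  ds = ds.shuffle(1024).repeat().batch(
      16, drop_remainder=strategy.extended.experimental_require_static_shapes)
  iterator = strategy.make_dataset_iterator(ds)
  sess.run(iterator.initialize())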
def broadcast_to(self, tensor, destinations):
"""Mirror a tensor on one device to all worker devices.
@@ -1660,6 +1718,18 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
replication_mode=InputReplicationMode.PER_WORKER):
return input_fn(InputContext()).make_initializable_iterator()
def _experimental_make_numpy_dataset(self, numpy_input, session):
numpy_flat = nest.flatten(numpy_input)
vars_flat = tuple(
variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
trainable=False, use_resource=True)
for i in numpy_flat
)
for v, i in zip(vars_flat, numpy_flat):
numpy_dataset.init_var_from_numpy(v, i, session)
vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
return dataset_ops.Dataset.from_tensor_slices(vars_nested)
def _broadcast_to(self, tensor, destinations):
if destinations is None:
return tensor

View File

@@ -29,6 +29,7 @@ from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import shared_variable_creator
from tensorflow.python.distribute import values
@@ -460,6 +461,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
self._input_workers = input_lib.InputWorkers(self._device_map)
self._inferred_cross_device_ops = cross_device_ops_lib.choose_the_best(
devices)
self._host_input_device = numpy_dataset.SingleDevice("/cpu:0")
def _initialize_multi_worker(self, devices):
"""Initializes the object for multi-worker training."""
@@ -488,6 +490,7 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
# their ops will end up on the cpu device of its first worker, e.g.
# "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
self._default_device = workers[0]
self._host_input_device = numpy_dataset.SingleDevice(workers[0])
self._device_map = values.ReplicaDeviceMap(devices)
self._input_workers = input_lib.InputWorkers(
@@ -501,6 +504,9 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
if colocate_with is None:
device_map = self._device_map
logical_device = 0 # TODO(josh11b): Get logical device from scope here.
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
else:
device_map = colocate_with.device_map
logical_device = colocate_with.logical_device
@@ -571,6 +577,10 @@ class MirroredExtended(distribute_lib.DistributionStrategyExtended):
return input_lib.InputFunctionIterator(
input_fn, self._input_workers, input_contexts)
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._host_input_device, session)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values=None):

View File

@@ -0,0 +1,97 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Code for creating a dataset out of a NumPy array."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
def init_var_from_numpy(input_var, numpy_input, session):
"""Initialize `input_var` to `numpy_input` using `session` in graph mode."""
with ops.init_scope():
if context.executing_eagerly():
input_var.assign(numpy_input)
return
assert session is not None
session.run(input_var.initializer)
start_placeholder = array_ops.placeholder(dtypes.int64, ())
end_placeholder = array_ops.placeholder(dtypes.int64, ())
slice_placeholder = array_ops.placeholder(input_var.dtype)
assign_slice_op = input_var[start_placeholder:end_placeholder].assign(
slice_placeholder)
# If each batch element is > 64 MB, then we copy each batch element
# individually. Otherwise, each slice holds
# ceil(64 MB / byte size per batch element) batch elements, so each slice is
# at least 64 MB and, even allowing for padding, roughly under 128 MB. Using
# ceil() guarantees we copy at least one batch element per slice.
# Calculate the size of each batch element.
byte_size_per_batch_element = (
np.prod(numpy_input.shape[1:]) * input_var.dtype.size)
# Calculate number of elements we want to copy per slice.
batch_size_per_slice = int(
np.ceil((64 << 20) / byte_size_per_batch_element))
# Copy slices of the above size starting at 0, except the last slice will be
# smaller.
start = 0
limit = numpy_input.shape[0]
while start < limit:
end = min(start + batch_size_per_slice, limit)
session.run(assign_slice_op, feed_dict={
start_placeholder: start,
end_placeholder: end,
slice_placeholder: numpy_input[start:end]})
start = end
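A quick check of the slicing arithmetic above, for a hypothetical float32 input of shape (1000, 256, 256):

import numpy as np

shape = (1000, 256, 256)
itemsize = np.dtype(np.float32).itemsize
# Each batch element is 256 * 256 * 4 bytes = 256 KB.
byte_size_per_batch_element = np.prod(shape[1:]) * itemsize
batch_size_per_slice = int(np.ceil((64 << 20) / byte_size_per_batch_element))
assert byte_size_per_batch_element == 256 * 1024
assert batch_size_per_slice == 256  # Each assign copies 256 elements, ~64 MB.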
def one_host_numpy_dataset(numpy_input, colocate_with, session):
"""Create a dataset on `colocate_with` from `numpy_input`."""
def create_colocated_variable(next_creator, *args, **kwargs):
kwargs["colocate_with"] = colocate_with
return next_creator(*args, **kwargs)
numpy_flat = nest.flatten(numpy_input)
with variable_scope.variable_creator_scope(create_colocated_variable):
vars_flat = tuple(variable_scope.variable(array_ops.zeros(i.shape, i.dtype),
trainable=False)
for i in numpy_flat)
for v, i in zip(vars_flat, numpy_flat):
init_var_from_numpy(v, i, session)
vars_nested = nest.pack_sequence_as(numpy_input, vars_flat)
return dataset_ops.Dataset.from_tensor_slices(vars_nested)
class SingleDevice(object):
"""Used with `colocate_with` to create a non-mirrored variable."""
def __init__(self, device):
self.device = device

View File

@@ -0,0 +1,44 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for numpy_dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
class InitVarFromNumpyTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_creating_var_with_numpy_arrays(self):
with self.cached_session() as session:
x = np.asarray(np.random.random((64, 3)), dtype=np.float32)
initial = np.zeros_like(x)
var_x = variable_scope.variable(initial)
numpy_dataset.init_var_from_numpy(var_x, x, session)
val = self.evaluate(var_x.value())
# Verify that the numpy value is copied to the variable.
self.assertAllEqual(x, val)
if __name__ == '__main__':
test.main()

View File

@@ -26,6 +26,7 @@ from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
@@ -137,6 +138,7 @@ class ParameterServerStrategyExtended(
assert cluster_spec.as_dict()
worker_device = "/job:%s/task:%d" % (task_type, task_id)
self._input_host_device = numpy_dataset.SingleDevice(worker_device)
# Define compute devices which is a list of device strings and one for each
# replica. When there are GPUs, replicate operations on these GPUs.
@@ -195,6 +197,7 @@ class ParameterServerStrategyExtended(
def _initialize_local(self, cluster_resolver):
"""Initialize internal devices for local training."""
worker_device = device_util.canonicalize("/device:CPU:0")
self._input_host_device = numpy_dataset.SingleDevice(worker_device)
num_gpus = cluster_resolver.num_accelerators()
# Define compute devices which is a list of device strings and one for each
# replica. When there are GPUs, replicate operations on these GPUs.
@@ -262,6 +265,10 @@ class ParameterServerStrategyExtended(
return input_lib.InputFunctionIterator(input_fn, self._input_workers,
[input_context])
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._input_host_device, session)
def _broadcast_to(self, tensor, destinations):
# This is both a fast path for Python constants, and a way to delay
# converting Python values to a tensor until we know what type it
@@ -329,8 +336,12 @@ class ParameterServerStrategyExtended(
var_creator = next_creator
if "colocate_with" in kwargs:
colocate_with = kwargs["colocate_with"]
if isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return var_creator(*args, **kwargs)
with ops.device(None):
with ops.colocate_with(kwargs["colocate_with"]):
with ops.colocate_with(colocate_with):
return var_creator(*args, **kwargs)
with ops.colocate_with(None, ignore_existing=True):

View File

@@ -34,25 +34,13 @@ from tensorflow.python.keras import callbacks
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training.mode_keys import ModeKeys
from tensorflow.python.util import nest
def validate_not_in_strategy_scope():
"""Validate fit/eval/predict are not running in DS scope."""
if distribution_strategy_context.has_distribution_strategy():
if distribution_strategy_context.in_cross_replica_context():
raise RuntimeError(
'Fit/Eval/Predict should not be run inside the tf.distribute.Strategy'
' scope. Only model creation and compilation should be in '
'tf.distribute.Strategy scope.')
def set_weights(distribution_strategy, dist_model, weights):
"""Sets the weights of the replicated models.
@@ -530,100 +518,11 @@ def get_batch_dimension(iterator):
return dims[0] if dims else None
def get_cpu_device(distribution_strategy):
"""Returns the CPU device of the TPU host or the default CPU device string.
Args:
distribution_strategy: The DistributionStrategy used to compile the model.
Returns:
A device string which is the TPU host's CPU device in case of
TPUDistributionStrategy or the default CPU device string in all other
cases.
Raises:
NotImplementedError: We currently don't support copying numpy data to
multiple hosts in the case of Cloud TPU pods.
"""
if is_tpu_strategy(distribution_strategy):
if distribution_strategy.extended.num_hosts > 1:
raise NotImplementedError('TPUDistributionStrategy does not '
'support numpy inputs when running on Cloud'
'TPU pods.')
return distribution_strategy.extended.get_host_cpu_device(0)
else:
# For all strategies except TPUDistributionStrategy
# TODO(anjalisridhar): We may need to modify this when we add support for
# multi-worker strategy.
return '/CPU:0'
def get_var_for_numpy(distribution_strategy, x):
if isinstance(x, list):
var_x = tuple([_get_var_for_numpy(distribution_strategy, single_input)
for single_input in x])
else:
var_x = _get_var_for_numpy(distribution_strategy, x)
return var_x
def _get_var_for_numpy(distribution_strategy, input_array):
"""Creates a variable and assigns the value of the numpy array to it.
Args:
distribution_strategy: The DistributionStrategy used to compile the model.
input_array: The input numpy array whose value will be assigned to the
variable we create.
Returns:
The variable to which we will copy the value of the input numpy array.
"""
with ops.device(get_cpu_device(distribution_strategy)):
# Create and initialize a variable on the CPU device. This is the CPU
# device of the host in the case of TPUDistributionStrategy.
input_var = variables.VariableV1(array_ops.zeros(input_array.shape,
input_array.dtype),
trainable=False, use_resource=True)
K.get_session().run(input_var.initializer)
# Create a placeholder for the numpy array input slices. We copy the value
# of the input numpy array to the variable in slices of size 64 MB to avoid
# running into memory issues or RPC message limits.
start_placeholder = array_ops.placeholder(dtypes.int64, ())
end_placeholder = array_ops.placeholder(dtypes.int64, ())
slice_placeholder = array_ops.placeholder(input_var.dtype)
assign_slice_op = input_var[start_placeholder:end_placeholder].assign(
slice_placeholder)
# If each batch element is > 64 MB, then we copy each batch element
# individually. Otherwise, the slices will be < 128 MB. There might be padding
# which might mean that the slices are 128 MB even if the size of the
# tensor allocated is less than 128 MB.
# This formula gives slices with size:
# ceil(64 MB / byte size per batch element) bytes.
# Using ceil() guarantees we get a number >= 1.
# Calculate the size of each batch element.
byte_size_per_batch_element = np.prod(input_array.shape[1:]) * \
input_var.dtype.size
# Calculate number of elements we want to copy per slice.
batch_size_per_slice = int(np.ceil((64 << 20) / byte_size_per_batch_element))
# Copy slices of the above size starting at 0, except the last slice will be
# smaller.
start = 0
limit = input_array.shape[0]
while start < limit:
end = min(start + batch_size_per_slice, limit)
K.get_session().run(assign_slice_op, feed_dict={
start_placeholder: start,
end_placeholder: end,
slice_placeholder: input_array[start:end]})
start = end
return input_var
def list_to_tuple(maybe_list):
"""Datasets treat lists specially, so switch them to tuples."""
if isinstance(maybe_list, list):
return tuple(maybe_list)
return maybe_list
def _get_input_from_iterator(iterator, model):

View File

@@ -2161,53 +2161,47 @@ class Model(Network):
'you should specify the `{steps_name}` argument.'
.format(steps_name=steps_name))
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
# We need to use the drop_remainder argument to allow for a static
# input shape which is required for TPUs.
drop_remainder = self._distribution_strategy.require_static_shapes
if y is not None:
var_x = distributed_training_utils.get_var_for_numpy(
self._distribution_strategy, x)
var_y = distributed_training_utils.get_var_for_numpy(
self._distribution_strategy, y)
if sample_weight is not None:
var_sample_weights = distributed_training_utils.get_var_for_numpy(
self._distribution_strategy, sample_weight)
if ops.executing_eagerly_outside_functions():
session = None
else:
session = K.get_session()
x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y,
var_sample_weights))
with self._distribution_strategy.scope():
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
x = distributed_training_utils.list_to_tuple(x)
if y is not None:
y = distributed_training_utils.list_to_tuple(y)
if sample_weight is not None:
sample_weight = distributed_training_utils.list_to_tuple(
sample_weight)
in_tuple = (x, y, sample_weight)
else:
in_tuple = (x, y)
else:
x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y))
in_tuple = x
if shuffle:
# 1024 is a good buffer size since it is much larger than the average
# batch size provided by the user and provides sufficient randomness.
# One thing to keep in mind is the memory usage based on the size of
# each sample.
x = x.shuffle(1024)
x = x.repeat()
x = x.batch(batch_size, drop_remainder=drop_remainder)
y = None
sample_weight = None
shuffle_buffer = 1024
else:
shuffle_buffer = None
iterator = self._distribution_strategy.experimental_make_numpy_iterator(
in_tuple, batch_size, num_epochs=None, shuffle=shuffle_buffer,
session=session)
else:
# This case is for the predict call where the dataset only contains
# inputs and no targets, i.e. it does not return a tuple
var_x = distributed_training_utils.get_var_for_numpy(
self._distribution_strategy, x)
x = dataset_ops.Dataset.from_tensor_slices(var_x)
x = x.batch(batch_size, drop_remainder=drop_remainder)
assert isinstance(x, dataset_ops.DatasetV2)
training_utils.validate_dataset_input(x, y, sample_weight,
validation_split)
iterator = self._distribution_strategy.make_dataset_iterator(x)
assert isinstance(x, dataset_ops.DatasetV2)
with self._distribution_strategy.scope():
iterator = self._distribution_strategy.make_dataset_iterator(x)
init_op = iterator.initialize()
if not context.executing_eagerly():
K.get_session().run(init_op)
training_utils.validate_dataset_input(x, y, sample_weight,
validation_split)
return iterator
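With this change, numpy inputs to Keras under a distribution strategy flow through `experimental_make_numpy_iterator`; a rough user-level sketch, assuming the v1-era `distribute` argument to `compile` and made-up shapes:

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
inputs = np.zeros((64, 3), dtype=np.float32)
targets = np.zeros((64, 4), dtype=np.float32)

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
optimizer = tf.train.GradientDescentOptimizer(0.001)
# `distribute` attaches the strategy to the model in TF 1.x Keras.
model.compile(optimizer, 'mse', distribute=strategy)
model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0)
model.evaluate(inputs, targets, batch_size=8)
model.predict(inputs, batch_size=8)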
def _standardize_user_data(self,

View File

@@ -59,8 +59,6 @@ def fit_distributed(model,
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
steps_per_epoch, batch_size = (
distributed_training_utils.get_input_params(
model._distribution_strategy, first_x_value, steps_per_epoch,
@@ -141,8 +139,6 @@ def evaluate_distributed(model,
distributed_training_utils.validate_inputs(x, y, model._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
steps, batch_size = distributed_training_utils.get_input_params(
model._distribution_strategy, first_x_value, steps, batch_size)
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
@@ -179,8 +175,6 @@ def predict_distributed(model,
x, None, model._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
# TODO(b/122314600): Remove the scope validate.
distributed_training_utils.validate_not_in_strategy_scope()
steps, batch_size = distributed_training_utils.get_input_params(
model._distribution_strategy, first_x_value, steps, batch_size)
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)

View File

@@ -554,14 +554,15 @@ class Optimizer(
# by most optimizers. It relies on the subclass implementing the following
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
# Handle DistributionStrategy case.
if distribute_ctx.get_cross_replica_context():
raise RuntimeError("Use `_distributed_apply()` instead of "
"`apply_gradients()` in a cross-replica context.")
# TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
# TODO(isaprykin): Get rid of `has_strategy()` check by
# always calling _distributed_apply(), using the default distribution
# as needed.
if distribute_ctx.has_distribution_strategy():
if distribute_ctx.has_strategy():
# Handle DistributionStrategy case.
if distribute_ctx.in_cross_replica_context():
raise RuntimeError("Use `_distributed_apply()` instead of "
"`apply_gradients()` in a cross-replica context.")
grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
return distribute_ctx.get_replica_context().merge_call(
self._distributed_apply, args=(grads_and_vars, global_step, name))

View File

@@ -75,6 +75,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_make_numpy_iterator"
argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], "
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -50,6 +50,10 @@ tf_class {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_make_numpy_dataset"
argspec: "args=[\'self\', \'numpy_input\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_steps_on_iterator"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "

View File

@@ -74,6 +74,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_make_numpy_iterator"
argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], "
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -75,6 +75,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_make_numpy_iterator"
argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], "
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@@ -50,6 +50,10 @@ tf_class {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_make_numpy_dataset"
argspec: "args=[\'self\', \'numpy_input\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_run_steps_on_iterator"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "

View File

@@ -74,6 +74,10 @@ tf_class {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_make_numpy_iterator"
argspec: "args=[\'self\', \'numpy_input\', \'batch_size\', \'num_epochs\', \'shuffle\', \'session\'], varargs=None, keywords=None, defaults=[\'1\', \'1024\', \'None\'], "
}
member_method {
name: "experimental_run"
argspec: "args=[\'self\', \'fn\', \'input_iterator\'], varargs=None, keywords=None, defaults=[\'None\'], "