Automatically use a single core for stateful RNNs in Keras TPU.

PiperOrigin-RevId: 211532963
Jianwei Xie, 2018-09-04 15:11:28 -07:00; committed by TensorFlower Gardener
parent ee24255e3d
commit 4fbc4e5b98


@@ -170,11 +170,41 @@ class TPUDistributionStrategy(object):
worker_re = re.compile('/job:([^/]+)')
for device in metadata.devices:
if 'TPU:0' in device.name:
self.worker_name = worker_re.search(device.name).group(1)
self._worker_name = worker_re.search(device.name).group(1)
break
def _make_assignment_for_model(self, cpu_model):
"""Makes a `TPUAssignment` for the passed in `cpu_model`."""
num_cores = self._num_cores
if num_cores > 1 and cpu_model.stateful:
logging.warning(
'Model replication does not currently support stateful models. '
'Degrading to a single core.')
num_cores = 1
return TPUAssignment(
worker_name=self._worker_name, num_cores=num_cores)
class TPUAssignment(object):
"""This is object holding TPU resources assignment for the concrete model.
`TPUDistributionStrategy` is responsible to create the instance of
`TPUAssignment`, so, it can dynamically adjust the `num_cores` to use based on
model and input batch sizes.
"""
def __init__(self, worker_name, num_cores):
self._worker_name = worker_name
self._num_cores = num_cores
@property
def worker_name(self):
return self._worker_name
@property
def num_towers(self):
# TODO(xiejw): Support automatically assigning num_cores based on inputs.
return self._num_cores
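For context, the fallback introduced above can be read as a tiny standalone rule. A minimal sketch (hypothetical helper, not the TF source; it only mirrors the branch in _make_assignment_for_model):

import logging

def pick_num_cores(requested_cores, model_is_stateful):
    # Stateful models cannot be replicated yet, so degrade to one core.
    if requested_cores > 1 and model_is_stateful:
        logging.warning('Model replication does not currently support '
                        'stateful models. Degrading to a single core.')
        return 1
    return requested_cores

assert pick_num_cores(8, model_is_stateful=True) == 1
assert pick_num_cores(8, model_is_stateful=False) == 8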
@@ -495,8 +525,8 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_dict[tensor] = value
return infeed_dict
def __init__(self, distribution_strategy):
self._strategy = distribution_strategy
def __init__(self, tpu_assignment):
self._tpu_assignment = tpu_assignment
def _split_tensors(self, inputs):
"""Split input data across shards.
@@ -509,16 +539,16 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
Returns:
List of lists containing the input to feed to each TPU shard.
"""
if self._strategy.num_towers == 1:
if self._tpu_assignment.num_towers == 1:
return [inputs]
batch_size = inputs[0].shape[0]
assert batch_size % self._strategy.num_towers == 0, (
'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
(batch_size, self._strategy.num_towers))
shard_size = batch_size // self._strategy.num_towers
assert batch_size % self._tpu_assignment.num_towers == 0, (
'batch_size must be divisible by the number of TPU cores in use (%s '
'vs %s)' % (batch_size, self._tpu_assignment.num_towers))
shard_size = batch_size // self._tpu_assignment.num_towers
input_list = []
for index in range(self._strategy.num_towers):
for index in range(self._tpu_assignment.num_towers):
shard_inputs = [
x[index * shard_size:(index + 1) * shard_size] for x in inputs
]
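The split rule here is a plain even partition of the leading batch dimension. A minimal numpy sketch of the same logic (standalone, hypothetical helper):

import numpy as np

def split_across_towers(inputs, num_towers):
    # Evenly partition the leading (batch) dimension, one slice per tower.
    batch_size = inputs[0].shape[0]
    assert batch_size % num_towers == 0
    shard_size = batch_size // num_towers
    return [[x[i * shard_size:(i + 1) * shard_size] for x in inputs]
            for i in range(num_towers)]

features = np.arange(8).reshape(8, 1)
shards = split_across_towers([features], num_towers=2)
assert (shards[0][0][:, 0] == np.arange(0, 4)).all()  # rows 0..3 -> tower 0
assert (shards[1][0][:, 0] == np.arange(4, 8)).all()  # rows 4..7 -> tower 1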
@@ -533,8 +563,9 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_op = []
shard_infeed_tensors = []
for shard_id in range(self._strategy.num_towers):
with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
for shard_id in range(self._tpu_assignment.num_towers):
with ops.device(
'/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_tensors = []
with ops.device('/device:TPU:%d' % shard_id):
for spec in input_specs:
@@ -573,30 +604,31 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# TODO(saeta): Verify tpu_model_op is as expected!
return {}
def __init__(self, dataset, distribution_strategy, tpu_session):
# pylint: disable=redefined-outer-name
def __init__(self, dataset, tpu_assignment, tpu_session):
"""Constructs a TPUDatasetInfeedManager.
Must be called within a `KerasTPUModel.tpu_session` context!
Args:
dataset: A `tf.data.Dataset` to infeed.
distribution_strategy: The `TPUDistributionStrategy` used to configure the
tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
tpu_session: The `tf.Session` object used for running the TPU model.
"""
self._verify_dataset_shape(dataset)
self._dataset = dataset
self._strategy = distribution_strategy
self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
dummy_x_shape[0] *= distribution_strategy.num_towers
dummy_x_shape[0] *= tpu_assignment.num_towers
dummy_y_shape = dataset.output_shapes[1].as_list()
dummy_y_shape[0] *= distribution_strategy.num_towers
dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
tpu_session.run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
for i in range(distribution_strategy.num_towers):
for i in range(tpu_assignment.num_towers):
with ops.control_dependencies(ctrl_deps): # Ensure deterministic
# TODO(saeta): Ensure correct placement!
get_next_op = self._iterator.get_next()
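The dummy-shape bookkeeping above scales the per-core dataset batch by the tower count so that Keras' shape checks see the global batch size. A toy illustration (the shapes are assumed, not from the source):

per_core_x_shape = [32, 28, 28]  # hypothetical dataset.output_shapes[0]
num_towers = 2
dummy_x_shape = list(per_core_x_shape)
dummy_x_shape[0] *= num_towers
assert dummy_x_shape == [64, 28, 28]  # Keras sees the global batch of 64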
@@ -676,10 +708,11 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
def build_infeed_from_input_specs(self, input_specs, execution_mode):
shard_infeed_tensors = self._get_next_ops
assert len(shard_infeed_tensors) == self._strategy.num_towers
assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
infeed_ops = []
for shard_id in range(self._strategy.num_towers):
with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
for shard_id in range(self._tpu_assignment.num_towers):
with ops.device(
'/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_ops.append(
tpu_ops.infeed_enqueue_tuple(
shard_infeed_tensors[shard_id],
@@ -702,10 +735,10 @@ class TPUFunction(object):
instead of being injected as `feed_dict` items or fetches.
"""
def __init__(self, model, execution_mode, strategy):
def __init__(self, model, execution_mode, tpu_assignment):
self.model = model
self.execution_mode = execution_mode
self._strategy = strategy
self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
@@ -757,7 +790,8 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
with keras_tpu_variables.replicated_scope(self._strategy.num_towers):
with keras_tpu_variables.replicated_scope(
self._tpu_assignment.num_towers):
self._cloned_model = models.clone_model(self.model)
# Create a copy of the optimizer for this graph.
@@ -827,7 +861,7 @@ class TPUFunction(object):
# `execute op` replicates `_model_fn` `num_replicas` times, with each shard
# running on a different logical core.
compile_op, execute_op = tpu.split_compile_and_replicate(
_model_fn, inputs=[[]] * self._strategy.num_towers)
_model_fn, inputs=[[]] * self._tpu_assignment.num_towers)
# Generate CPU side operations to enqueue features/labels and dequeue
# outputs from the model call.
@@ -835,8 +869,9 @@ class TPUFunction(object):
input_specs, self.execution_mode)
# Build output ops.
outfeed_op = []
for shard_id in range(self._strategy.num_towers):
with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
for shard_id in range(self._tpu_assignment.num_towers):
with ops.device(
'/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
outfeed_op.extend(
tpu_ops.outfeed_dequeue_tuple(
dtypes=[spec.dtype for spec in self._outfeed_spec],
@@ -886,7 +921,7 @@ class TPUFunction(object):
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
return mgr
return TPUNumpyInfeedManager(self.model._strategy)
return TPUNumpyInfeedManager(self.model._tpu_assignment)
def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
"""Looks up the corresponding `TPUModelOp` for a given `input_specs`.
@@ -958,7 +993,7 @@ class TPUFunction(object):
outputs = [[]] * len(self._outfeed_spec)
outputs_per_replica = len(self._outfeed_spec)
for i in range(self._strategy.num_towers):
for i in range(self._tpu_assignment.num_towers):
output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
outputs_per_replica]
for j in range(outputs_per_replica):
@@ -967,7 +1002,7 @@ class TPUFunction(object):
return [np.concatenate(group) for group in outputs]
else:
return outfeed_outputs[:len(outfeed_outputs) //
self._strategy.num_towers]
self._tpu_assignment.num_towers]
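For reference, the regrouping above can be reproduced in a standalone numpy sketch: the flattened outfeed carries outputs_per_replica tensors per tower, and entries with the same output index are concatenated across towers. (Note the sketch builds independent sublists; `[[]] * n`, as in the surrounding context line, aliases one list n times.)

import numpy as np

outfeed_outputs = [np.full((2, 1), 0.), np.full((2, 1), 1.),  # tower 0
                   np.full((2, 1), 2.), np.full((2, 1), 3.)]  # tower 1
num_towers, outputs_per_replica = 2, 2
outputs = [[] for _ in range(outputs_per_replica)]
for i in range(num_towers):
    group = outfeed_outputs[i * outputs_per_replica:
                            (i + 1) * outputs_per_replica]
    for j in range(outputs_per_replica):
        outputs[j].append(group[j])
merged = [np.concatenate(g) for g in outputs]
assert merged[0].shape == (4, 1)  # output 0 from both towers, stacked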
def __call__(self, inputs):
"""__call__ executes the function on the computational hardware.
@@ -1119,11 +1154,11 @@ class KerasTPUModel(models.Model):
self.predict_function = None
self.test_function = None
self.train_function = None
self._strategy = strategy
cluster_resolver = self._strategy._tpu_cluster_resolver
cluster_resolver = strategy._tpu_cluster_resolver
self._tpu_name_or_address = cluster_resolver.get_master()
self._cpu_model = cpu_model
self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
self._tpu_model = None
self._tpu_weights_initialized = False
@@ -1146,7 +1181,7 @@ class KerasTPUModel(models.Model):
return {
'cpu_model': self._cpu_model,
'tpu_name_or_address': self._tpu_name_or_address,
'strategy': self._strategy,
'tpu_assignment': self._tpu_assignment,
}
def compile(self,
@@ -1207,7 +1242,7 @@ class KerasTPUModel(models.Model):
'/keras')
if callable(x):
with self.tpu_session() as sess,\
ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -1215,7 +1250,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
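Taken together, these branches define the tf.data input contract for fit: pass a callable that builds the dataset, supply steps_per_epoch, and leave y unset. A hedged usage sketch (tpu_model is assumed to come from the contrib-era keras_to_tpu_model; shapes are illustrative):

import numpy as np
import tensorflow as tf

features = np.random.rand(1024, 10).astype(np.float32)
labels = np.random.rand(1024, 1).astype(np.float32)

def make_dataset():
    # The global batch must stay divisible by the number of towers in use.
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    return ds.repeat().batch(32, drop_remainder=True)

# tpu_model: hypothetical KerasTPUModel from tf.contrib.tpu.keras_to_tpu_model.
tpu_model.fit(make_dataset, steps_per_epoch=100, epochs=2)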
@@ -1236,7 +1272,8 @@ class KerasTPUModel(models.Model):
if validation_steps is None:
raise ValueError('When using tf.data as validation for a model, you '
'should specify the validation_steps argument.')
infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
val_x = infeed_manager.dummy_x
@@ -1313,7 +1350,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1740,20 +1778,24 @@ class KerasTPUModel(models.Model):
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy)
self,
model_fn_lib.ModeKeys.TRAIN,
tpu_assignment=self._tpu_assignment)
return self.train_function
def _make_test_function(self):
if not self.test_function:
self.test_function = TPUFunction(
self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy)
self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
return self.test_function
def _make_predict_function(self):
if not self.predict_function:
self.predict_function = TPUFunction(
self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy)
self,
model_fn_lib.ModeKeys.PREDICT,
tpu_assignment=self._tpu_assignment)
return self.predict_function
def _initialize_weights(self, cloned_model):
@@ -1825,6 +1867,7 @@ class KerasTPUModel(models.Model):
self._session.close()
# pylint: disable=bad-continuation
def _validate_shapes(model):
"""Validate that all layers in `model` have constant shape."""
for layer in model.layers:
@@ -1852,10 +1895,13 @@ Layer: %(layer)s
Input shape: %(input_shape)s
Output shape: %(output_shape)s
""" % {
'layer': layer,
'input_shape': layer.input_shape,
'output_shape': layer.output_shape
})
'layer': layer,
'input_shape': layer.input_shape,
'output_shape': layer.output_shape
})
# pylint: enable=bad-continuation
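The validation enforced here boils down to requiring every non-batch dimension to be statically known before TPU compilation. A hypothetical standalone version of the rule (not the TF source):

def has_constant_shape(shape):
    # shape is (batch, d1, d2, ...); only the batch dimension may be None.
    return None not in shape[1:]

assert has_constant_shape((None, 28, 28))
assert not has_constant_shape((None, None, 128))  # dynamic time axis fails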
@experimental