Automatically use a single core for stateful RNNs in Keras TPU.
PiperOrigin-RevId: 211532963
This commit is contained in:
parent
ee24255e3d
commit
4fbc4e5b98
@ -170,11 +170,41 @@ class TPUDistributionStrategy(object):
|
||||
worker_re = re.compile('/job:([^/]+)')
|
||||
for device in metadata.devices:
|
||||
if 'TPU:0' in device.name:
|
||||
self.worker_name = worker_re.search(device.name).group(1)
|
||||
self._worker_name = worker_re.search(device.name).group(1)
|
||||
break
|
||||
|
||||
def _make_assignment_for_model(self, cpu_model):
  """Builds the `TPUAssignment` to use for `cpu_model`.

  A stateful model is assigned a single core (after logging a warning),
  because model replication does not currently support stateful models.
  Any other model is assigned all of this strategy's cores.

  Args:
    cpu_model: The model whose `stateful` attribute is consulted.

  Returns:
    A `TPUAssignment` for this strategy's worker and the chosen core count.
  """
  cores_to_use = self._num_cores
  # `stateful` is only consulted when more than one core is available.
  degrade_to_single_core = cores_to_use > 1 and cpu_model.stateful
  if degrade_to_single_core:
    logging.warning(
        'Model replication does not currently support stateful models. '
        'Degrading to a single core.')
    cores_to_use = 1

  return TPUAssignment(
      worker_name=self._worker_name, num_cores=cores_to_use)
|
||||
|
||||
|
||||
class TPUAssignment(object):
  """Holds the TPU resource assignment for one concrete model.

  `TPUDistributionStrategy` is responsible for creating instances of
  `TPUAssignment`, so it can dynamically adjust the `num_cores` to use based
  on the model and input batch sizes.

  Args:
    worker_name: Name of the worker; callers format it into
      `/job:<worker_name>/device:CPU:0` device strings.
    num_cores: Number of TPU cores assigned to the model; exposed through
      the `num_towers` property.
  """

  def __init__(self, worker_name, num_cores):
    self._worker_name = worker_name
    self._num_cores = num_cores

  def __repr__(self):
    # Debug-friendly representation showing both stored values.
    return '%s(worker_name=%r, num_cores=%r)' % (
        type(self).__name__, self._worker_name, self._num_cores)

  @property
  def worker_name(self):
    """The worker name given at construction time."""
    return self._worker_name

  @property
  def num_towers(self):
    """The number of cores given at construction time as `num_cores`."""
    # TODO(xiejw): Support automatically assign num_cores based on inputs.
    return self._num_cores
|
||||
|
||||
|
||||
@ -495,8 +525,8 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
|
||||
infeed_dict[tensor] = value
|
||||
return infeed_dict
|
||||
|
||||
def __init__(self, distribution_strategy):
|
||||
self._strategy = distribution_strategy
|
||||
def __init__(self, tpu_assignment):
|
||||
self._tpu_assignment = tpu_assignment
|
||||
|
||||
def _split_tensors(self, inputs):
|
||||
"""Split input data across shards.
|
||||
@ -509,16 +539,16 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
|
||||
Returns:
|
||||
List of lists containing the input to feed to each TPU shard.
|
||||
"""
|
||||
if self._strategy.num_towers == 1:
|
||||
if self._tpu_assignment.num_towers == 1:
|
||||
return [inputs]
|
||||
|
||||
batch_size = inputs[0].shape[0]
|
||||
assert batch_size % self._strategy.num_towers == 0, (
|
||||
'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
|
||||
(batch_size, self._strategy.num_towers))
|
||||
shard_size = batch_size // self._strategy.num_towers
|
||||
assert batch_size % self._tpu_assignment.num_towers == 0, (
|
||||
'batch_size must be divisible by the number of TPU cores in use (%s '
|
||||
'vs %s)' % (batch_size, self._tpu_assignment.num_towers))
|
||||
shard_size = batch_size // self._tpu_assignment.num_towers
|
||||
input_list = []
|
||||
for index in range(self._strategy.num_towers):
|
||||
for index in range(self._tpu_assignment.num_towers):
|
||||
shard_inputs = [
|
||||
x[index * shard_size:(index + 1) * shard_size] for x in inputs
|
||||
]
|
||||
@ -533,8 +563,9 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
|
||||
infeed_op = []
|
||||
shard_infeed_tensors = []
|
||||
|
||||
for shard_id in range(self._strategy.num_towers):
|
||||
with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
|
||||
for shard_id in range(self._tpu_assignment.num_towers):
|
||||
with ops.device(
|
||||
'/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
|
||||
infeed_tensors = []
|
||||
with ops.device('/device:TPU:%d' % shard_id):
|
||||
for spec in input_specs:
|
||||
@ -573,30 +604,31 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
|
||||
# TODO(saeta): Verify tpu_model_op is as expected!
|
||||
return {}
|
||||
|
||||
def __init__(self, dataset, distribution_strategy, tpu_session):
|
||||
# pylint: disable=redefined-outer-name
|
||||
def __init__(self, dataset, tpu_assignment, tpu_session):
|
||||
"""Constructs a TPUDatasetInfeedManager.
|
||||
|
||||
Must be called within a `KerasTPUModel.tpu_session` context!
|
||||
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset` to infeed.
|
||||
distribution_strategy: The `TPUDistributionStrategy` used to configure the
|
||||
tpu_assignment: The `TPUAssignment` used to configure the
|
||||
Keras TPU model.
|
||||
tpu_session: The `tf.Session` object used for running the TPU model.
|
||||
"""
|
||||
self._verify_dataset_shape(dataset)
|
||||
self._dataset = dataset
|
||||
self._strategy = distribution_strategy
|
||||
self._tpu_assignment = tpu_assignment
|
||||
dummy_x_shape = dataset.output_shapes[0].as_list()
|
||||
dummy_x_shape[0] *= distribution_strategy.num_towers
|
||||
dummy_x_shape[0] *= tpu_assignment.num_towers
|
||||
dummy_y_shape = dataset.output_shapes[1].as_list()
|
||||
dummy_y_shape[0] *= distribution_strategy.num_towers
|
||||
dummy_y_shape[0] *= tpu_assignment.num_towers
|
||||
self._iterator = dataset.make_initializable_iterator()
|
||||
tpu_session.run(self._iterator.initializer)
|
||||
|
||||
self._get_next_ops = []
|
||||
ctrl_deps = []
|
||||
for i in range(distribution_strategy.num_towers):
|
||||
for i in range(tpu_assignment.num_towers):
|
||||
with ops.control_dependencies(ctrl_deps): # Ensure deterministic
|
||||
# TODO(saeta): Ensure correct placement!
|
||||
get_next_op = self._iterator.get_next()
|
||||
@ -676,10 +708,11 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
|
||||
|
||||
def build_infeed_from_input_specs(self, input_specs, execution_mode):
|
||||
shard_infeed_tensors = self._get_next_ops
|
||||
assert len(shard_infeed_tensors) == self._strategy.num_towers
|
||||
assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
|
||||
infeed_ops = []
|
||||
for shard_id in range(self._strategy.num_towers):
|
||||
with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
|
||||
for shard_id in range(self._tpu_assignment.num_towers):
|
||||
with ops.device(
|
||||
'/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
|
||||
infeed_ops.append(
|
||||
tpu_ops.infeed_enqueue_tuple(
|
||||
shard_infeed_tensors[shard_id],
|
||||
@ -702,10 +735,10 @@ class TPUFunction(object):
|
||||
instead of being injected as `feed_dict` items or fetches.
|
||||
"""
|
||||
|
||||
def __init__(self, model, execution_mode, strategy):
|
||||
def __init__(self, model, execution_mode, tpu_assignment):
|
||||
self.model = model
|
||||
self.execution_mode = execution_mode
|
||||
self._strategy = strategy
|
||||
self._tpu_assignment = tpu_assignment
|
||||
self._compilation_cache = {}
|
||||
self._cloned_model = None
|
||||
|
||||
@ -757,7 +790,8 @@ class TPUFunction(object):
|
||||
# Clone our CPU model, running within the TPU device context.
|
||||
with TPURewriteContext(tpu_input_map):
|
||||
with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
|
||||
with keras_tpu_variables.replicated_scope(self._strategy.num_towers):
|
||||
with keras_tpu_variables.replicated_scope(
|
||||
self._tpu_assignment.num_towers):
|
||||
self._cloned_model = models.clone_model(self.model)
|
||||
|
||||
# Create a copy of the optimizer for this graph.
|
||||
@ -827,7 +861,7 @@ class TPUFunction(object):
|
||||
# `execute op` replicates `_model_fn` `num_replicas` times, with each shard
|
||||
# running on a different logical core.
|
||||
compile_op, execute_op = tpu.split_compile_and_replicate(
|
||||
_model_fn, inputs=[[]] * self._strategy.num_towers)
|
||||
_model_fn, inputs=[[]] * self._tpu_assignment.num_towers)
|
||||
|
||||
# Generate CPU side operations to enqueue features/labels and dequeue
|
||||
# outputs from the model call.
|
||||
@ -835,8 +869,9 @@ class TPUFunction(object):
|
||||
input_specs, self.execution_mode)
|
||||
# Build output ops.
|
||||
outfeed_op = []
|
||||
for shard_id in range(self._strategy.num_towers):
|
||||
with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
|
||||
for shard_id in range(self._tpu_assignment.num_towers):
|
||||
with ops.device(
|
||||
'/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
|
||||
outfeed_op.extend(
|
||||
tpu_ops.outfeed_dequeue_tuple(
|
||||
dtypes=[spec.dtype for spec in self._outfeed_spec],
|
||||
@ -886,7 +921,7 @@ class TPUFunction(object):
|
||||
for x, mgr in self.model._numpy_to_infeed_manager_list:
|
||||
if inputs[0] is x:
|
||||
return mgr
|
||||
return TPUNumpyInfeedManager(self.model._strategy)
|
||||
return TPUNumpyInfeedManager(self.model._tpu_assignment)
|
||||
|
||||
def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
|
||||
"""Looks up the corresponding `TPUModelOp` for a given `input_specs`.
|
||||
@ -958,7 +993,7 @@ class TPUFunction(object):
|
||||
outputs = [[]] * len(self._outfeed_spec)
|
||||
outputs_per_replica = len(self._outfeed_spec)
|
||||
|
||||
for i in range(self._strategy.num_towers):
|
||||
for i in range(self._tpu_assignment.num_towers):
|
||||
output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
|
||||
outputs_per_replica]
|
||||
for j in range(outputs_per_replica):
|
||||
@ -967,7 +1002,7 @@ class TPUFunction(object):
|
||||
return [np.concatenate(group) for group in outputs]
|
||||
else:
|
||||
return outfeed_outputs[:len(outfeed_outputs) //
|
||||
self._strategy.num_towers]
|
||||
self._tpu_assignment.num_towers]
|
||||
|
||||
def __call__(self, inputs):
|
||||
"""__call__ executes the function on the computational hardware.
|
||||
@ -1119,11 +1154,11 @@ class KerasTPUModel(models.Model):
|
||||
self.predict_function = None
|
||||
self.test_function = None
|
||||
self.train_function = None
|
||||
self._strategy = strategy
|
||||
|
||||
cluster_resolver = self._strategy._tpu_cluster_resolver
|
||||
cluster_resolver = strategy._tpu_cluster_resolver
|
||||
self._tpu_name_or_address = cluster_resolver.get_master()
|
||||
self._cpu_model = cpu_model
|
||||
self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
|
||||
self._tpu_model = None
|
||||
self._tpu_weights_initialized = False
|
||||
|
||||
@ -1146,7 +1181,7 @@ class KerasTPUModel(models.Model):
|
||||
return {
|
||||
'cpu_model': self._cpu_model,
|
||||
'tpu_name_or_address': self._tpu_name_or_address,
|
||||
'strategy': self._strategy,
|
||||
'tpu_assignment': self._tpu_assignment,
|
||||
}
|
||||
|
||||
def compile(self,
|
||||
@ -1207,7 +1242,7 @@ class KerasTPUModel(models.Model):
|
||||
'/keras')
|
||||
if callable(x):
|
||||
with self.tpu_session() as sess,\
|
||||
ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name):
|
||||
ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
|
||||
dataset = x()
|
||||
if steps_per_epoch is None:
|
||||
raise ValueError('When using tf.data as input to a model, you '
|
||||
@ -1215,7 +1250,8 @@ class KerasTPUModel(models.Model):
|
||||
if y is not None:
|
||||
raise ValueError('When using tf.data as input to a model, y must be '
|
||||
'None')
|
||||
infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
|
||||
infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
|
||||
sess)
|
||||
# Use dummy numpy inputs for the rest of Keras' shape checking. We
|
||||
# intercept them when building the model.
|
||||
x = infeed_manager.dummy_x
|
||||
@ -1236,7 +1272,8 @@ class KerasTPUModel(models.Model):
|
||||
if validation_steps is None:
|
||||
raise ValueError('When using tf.data as validation for a model, you '
|
||||
'should specify the validation_steps argument.')
|
||||
infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
|
||||
infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
|
||||
sess)
|
||||
# Use dummy numpy inputs for the rest of Keras' shape checking. We
|
||||
# intercept them when building the model.
|
||||
val_x = infeed_manager.dummy_x
|
||||
@ -1313,7 +1350,8 @@ class KerasTPUModel(models.Model):
|
||||
if y is not None:
|
||||
raise ValueError('When using tf.data as input to a model, y must be '
|
||||
'None')
|
||||
infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
|
||||
infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
|
||||
sess)
|
||||
# Use dummy numpy inputs for the rest of Keras' shape checking. We
|
||||
# intercept them when building the model.
|
||||
x = infeed_manager.dummy_x
|
||||
@ -1740,20 +1778,24 @@ class KerasTPUModel(models.Model):
|
||||
def _make_train_function(self):
|
||||
if not self.train_function:
|
||||
self.train_function = TPUFunction(
|
||||
self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy)
|
||||
self,
|
||||
model_fn_lib.ModeKeys.TRAIN,
|
||||
tpu_assignment=self._tpu_assignment)
|
||||
|
||||
return self.train_function
|
||||
|
||||
def _make_test_function(self):
|
||||
if not self.test_function:
|
||||
self.test_function = TPUFunction(
|
||||
self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy)
|
||||
self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
|
||||
return self.test_function
|
||||
|
||||
def _make_predict_function(self):
|
||||
if not self.predict_function:
|
||||
self.predict_function = TPUFunction(
|
||||
self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy)
|
||||
self,
|
||||
model_fn_lib.ModeKeys.PREDICT,
|
||||
tpu_assignment=self._tpu_assignment)
|
||||
return self.predict_function
|
||||
|
||||
def _initialize_weights(self, cloned_model):
|
||||
@ -1825,6 +1867,7 @@ class KerasTPUModel(models.Model):
|
||||
self._session.close()
|
||||
|
||||
|
||||
# pylint: disable=bad-continuation
|
||||
def _validate_shapes(model):
|
||||
"""Validate that all layers in `model` have constant shape."""
|
||||
for layer in model.layers:
|
||||
@ -1852,10 +1895,13 @@ Layer: %(layer)s
|
||||
Input shape: %(input_shape)s
|
||||
Output shape: %(output_shape)s
|
||||
""" % {
|
||||
'layer': layer,
|
||||
'input_shape': layer.input_shape,
|
||||
'output_shape': layer.output_shape
|
||||
})
|
||||
'layer': layer,
|
||||
'input_shape': layer.input_shape,
|
||||
'output_shape': layer.output_shape
|
||||
})
|
||||
|
||||
|
||||
# pylint: enable=bad-continuation
|
||||
|
||||
|
||||
@experimental
|
||||
|
Loading…
x
Reference in New Issue
Block a user