Adds train_batch_size into TPUEstimator.

PiperOrigin-RevId: 161102679
Jianwei Xie 2017-07-06 11:05:53 -07:00 committed by TensorFlower Gardener
parent ad0ba9e6ea
commit 2d50faa5d5


@@ -43,6 +43,7 @@ from tensorflow.python.training import training

 _BATCH_SIZE_KEY = 'batch_size'
+_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]


 def _tpu_job(run_config):
@@ -51,6 +52,11 @@ def _tpu_job(run_config):
   return None if run_config.master in ['', 'local'] else 'tpu_worker'


+def _per_shard_batch_size(global_batch_size, run_config):
+  """Returns the batch size for each shard."""
+  return global_batch_size // run_config.tpu_config.num_shards
+
+
 class _SIGNAL(object):
   """Signal used to control the input thread of infeed."""
   NEXT_BATCH = 1
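
The new `_per_shard_batch_size` helper is the one place where the global-to-per-shard conversion happens. Below is a minimal sketch of the same arithmetic, using hypothetical stand-in config classes (the real objects are `tpu_config.RunConfig` and its nested TPU config):

    class FakeTpuConfig(object):  # stand-in, illustration only
      def __init__(self, num_shards):
        self.num_shards = num_shards

    class FakeRunConfig(object):  # stand-in for tpu_config.RunConfig
      def __init__(self, num_shards):
        self.tpu_config = FakeTpuConfig(num_shards)

    def per_shard_batch_size(global_batch_size, run_config):
      # Same floor division as _per_shard_batch_size above.
      return global_batch_size // run_config.tpu_config.num_shards

    assert per_shard_batch_size(1024, FakeRunConfig(num_shards=8)) == 128

The constructor (further down) separately enforces that the global batch size divides evenly, so the floor division never discards examples.
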
@@ -168,8 +174,11 @@ class TpuEstimator(estimator_lib.Estimator):
   replicating inputs and models for each core, and returning to host
   periodically to run hooks.

-  Note: TpuEstimator transforms a global batch size in params to a per-shard
-  batch size when calling the input_fn.
+  Note: For training (evaluate and predict support on TPU are not yet
+  implemented), TpuEstimator transforms the global batch size in params to a
+  per-shard batch size when calling `input_fn` and `model_fn`. Users should
+  specify `train_batch_size` in the constructor, and then get the per-shard
+  batch size in `input_fn` and `model_fn` via `params['batch_size']`.
   """

   def __init__(self,
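
Concretely, the intended call pattern after this change looks roughly like the sketch below; `my_input_fn`, `my_model_fn`, and `my_config` are illustrative names, not part of this commit:

    def my_input_fn(params):
      batch_size = params['batch_size']  # per-shard, not global
      ...  # build a per-shard input pipeline yielding `batch_size` examples

    def my_model_fn(features, labels, mode, params):
      batch_size = params['batch_size']  # the same per-shard value
      ...  # build the model

    estimator = TpuEstimator(
        model_fn=my_model_fn,
        config=my_config,        # a tpu_config.RunConfig
        use_tpu=True,
        train_batch_size=1024)   # global; split across num_shards
    estimator.train(input_fn=my_input_fn, max_steps=1000)
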
@@ -177,23 +186,59 @@ class TpuEstimator(estimator_lib.Estimator):
                model_dir=None,
                config=None,
                params=None,
-               use_tpu=True):
+               use_tpu=True,
+               train_batch_size=None):
+    """Constructs a `TpuEstimator` instance.
+
+    Args:
+      model_fn: Model function as required by `Estimator`. For training, the
+        returned `EstimatorSpec` cannot have hooks as it is not supported in
+        `TPUEstimator`.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model. If `None`, the
+        model_dir in `config` will be used if set. If both are set, they must
+        be the same. If both are `None`, a temporary directory will be used.
+      config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
+      params: An optional `dict` of hyper parameters that will be passed into
+        `input_fn` and `model_fn`. Keys are names of parameters, values are
+        basic Python types. There are reserved keys for `TPUEstimator`,
+        including 'batch_size'.
+      use_tpu: A bool indicating whether TPU support is enabled. Currently,
+        it is only applied to training. Evaluate and predict still happen on
+        CPU.
+      train_batch_size: An int representing the global training batch size.
+        TpuEstimator transforms this global batch size to a per-shard batch
+        size, as params['batch_size'], when calling `input_fn` and `model_fn`.
+        Cannot be `None` if `use_tpu` is `True`. Must be divisible by
+        `config.tpu_config.num_shards`.
+
+    Raises:
+      ValueError: if `params` contains any of the reserved keys.
+    """
     if config is None or not isinstance(config, tpu_config.RunConfig):
       raise ValueError(
           '`config` must be provided with type `tpu_config.RunConfig`')

-    if use_tpu and params is not None and _BATCH_SIZE_KEY in params:
-      if not isinstance(params[_BATCH_SIZE_KEY], int):
-        raise ValueError(
-            '`{}` in params must be an int'.format(_BATCH_SIZE_KEY))
-      params = copy.deepcopy(params)
+    if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
+      raise ValueError(
+          '{} are reserved keys but existed in params {}.'.format(
+              _RESERVED_PARAMS_KEYS, params))
+
+    if use_tpu:
+      if train_batch_size is None:
+        raise ValueError('`train_batch_size` cannot be `None`')
+      if not isinstance(train_batch_size, int):
+        raise ValueError('`train_batch_size` must be an int')
+      if train_batch_size < 1:
+        raise ValueError('`train_batch_size` must be positive')
+
       # The specified batch size is the batch size for the entire computation.
-      # The input_fn is called per-shard, so we want to calculate the per-shard
-      # batch size and pass that.
-      if params[_BATCH_SIZE_KEY] % config.tpu_config.num_shards != 0:
+      # The input_fn and model_fn are called per-shard, so we want to calculate
+      # the per-shard batch size and pass that.
+      if train_batch_size % config.tpu_config.num_shards != 0:
         raise ValueError(
             'batch size {} must be divisible by number of shards {}'
-            .format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards))
+            .format(train_batch_size, config.tpu_config.num_shards))

     if use_tpu:
       # Verifies the model_fn signature according to Estimator framework.
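
The validation order above matters: reserved keys in `params` are rejected regardless of `use_tpu`, while the `train_batch_size` checks run only on the TPU path. A condensed, hypothetical restatement of the same checks as a standalone function:

    def validate_ctor_args(params, use_tpu, train_batch_size, num_shards):
      # Mirrors the constructor checks above; illustration only.
      if params is not None and 'batch_size' in params:
        raise ValueError("'batch_size' is a reserved key in params")
      if not use_tpu:
        return  # no TPU: train_batch_size may stay None
      if train_batch_size is None:
        raise ValueError('`train_batch_size` cannot be `None`')
      if not isinstance(train_batch_size, int):
        raise ValueError('`train_batch_size` must be an int')
      if train_batch_size < 1:
        raise ValueError('`train_batch_size` must be positive')
      if train_batch_size % num_shards != 0:
        raise ValueError('batch size {} must be divisible by {}'.format(
            train_batch_size, num_shards))
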
@@ -201,7 +246,7 @@ class TpuEstimator(estimator_lib.Estimator):
       # We cannot store config and params in this constructor as parent
       # constructor might change them, such as assigning a temp dir for
       # config.model_dir.
-      model_function = wrapped_model_fn(model_fn)
+      model_function = wrapped_model_fn(model_fn, train_batch_size)
     else:
       model_function = model_fn
@@ -210,7 +255,8 @@ class TpuEstimator(estimator_lib.Estimator):
         model_dir=model_dir,
         config=config,
         params=params)
-    self.use_tpu = use_tpu
+    self._use_tpu = use_tpu
+    self._train_batch_size = train_batch_size

   def _create_global_step(self, graph):
     """Creates a global step suitable for TPUs.
@@ -252,9 +298,9 @@ class TpuEstimator(estimator_lib.Estimator):
         labels - `Tensor` or dictionary of `Tensor` with labels.

     Raises:
-      ValueError: if input_fn takes invalid arguments.
+      ValueError: if input_fn takes invalid arguments or does not have `params`.
     """
-    if not self.use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
+    if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
       return super(TpuEstimator, self)._call_input_fn(input_fn, mode)

     input_fn_args = estimator_lib._fn_args(input_fn)  # pylint: disable=protected-access
@@ -262,12 +308,16 @@ class TpuEstimator(estimator_lib.Estimator):
     kwargs = {}
     if 'params' in input_fn_args:
       kwargs['params'] = self.params  # a deep copy.
+    else:
+      raise ValueError('input_fn ({}) does not include params argument, '
+                       'required by TPUEstimator to pass batch size as '
+                       'params["batch_size"]'.format(input_fn))
     if 'config' in input_fn_args:
       kwargs['config'] = config

     # Now for TPU training.
-    if 'params' in kwargs and _BATCH_SIZE_KEY in kwargs['params']:
-      kwargs['params'][_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
+    per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
+    kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size

     job = _tpu_job(config)
     def placement_function(index):
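
With this change, an `input_fn` that does not declare a `params` argument now fails fast with a `ValueError` instead of silently training at an unintended batch size. A conforming `input_fn` sketch for TF 1.x (the random-tensor body is a placeholder):

    import tensorflow as tf

    def my_input_fn(params):
      # `params` is required: TPUEstimator delivers the per-shard
      # batch size here as params['batch_size'].
      batch_size = params['batch_size']
      features = tf.random_uniform([batch_size, 10])
      labels = tf.random_uniform([batch_size], maxval=2, dtype=tf.int32)
      return features, labels
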
@@ -287,8 +337,10 @@ class TpuEstimator(estimator_lib.Estimator):
           labels.append(result[1])
         else:
           features.append(result)

-    if not labels:
+    if not labels or all(l is None for l in labels):
       return _PerShardOutput(features), None
+
     return _PerShardOutput(features), _PerShardOutput(labels)
@@ -302,16 +354,22 @@ def _verify_estimator_spec(estimator_spec):
   return estimator_spec


-def _call_model_fn(model_fn, features, labels, mode, config, params):
+def _call_model_fn(model_fn, features, labels, mode, config, params,
+                   require_params=False):
   """Calls the model_fn with required parameters."""
   model_fn_args = estimator_lib._fn_args(model_fn)  # pylint: disable=protected-access
   kwargs = {}
   if 'mode' in model_fn_args:
     kwargs['mode'] = mode
-  if 'params' in model_fn_args:
-    kwargs['params'] = params
   if 'config' in model_fn_args:
     kwargs['config'] = config
+  if 'params' in model_fn_args:
+    kwargs['params'] = params
+  elif require_params:
+    raise ValueError(
+        'model_fn ({}) does not include params argument, '
+        'required by TPUEstimator to pass batch size as '
+        'params[\'batch_size\']'.format(model_fn))
   return model_fn(features=features, labels=labels, **kwargs)
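
Because `require_params` defaults to `False`, only the TPU path (via `_call_model_fn_with_tpu` below) enforces the `params` argument; CPU-side calls keep accepting `model_fn`s that omit it. The two call shapes, schematically:

    # CPU path: `params` is forwarded only if model_fn declares it.
    _call_model_fn(model_fn, features, labels, mode, config, params)

    # TPU path: a model_fn without `params` raises ValueError.
    _call_model_fn(model_fn, features, labels, mode, config, params,
                   require_params=True)
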
@@ -321,7 +379,7 @@ def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
   config = copy.deepcopy(config)
   params = copy.deepcopy(params)
   return _verify_estimator_spec(_call_model_fn(
-      model_fn, features, labels, mode, config, params))
+      model_fn, features, labels, mode, config, params, require_params=True))


 def _call_model_fn_without_tpu(
@@ -416,10 +474,10 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
   return (dequeue_fn, enqueue_fn)


-def wrapped_model_fn(model_fn):
+def wrapped_model_fn(model_fn, train_batch_size):
   """Returns a new model_fn, which wraps the TPU support."""

-  def _model_fn(features, labels, mode, config, params=None):
+  def _model_fn(features, labels, mode, config, params):
     """model_fn."""
     # TODO(jhseu): Move EVAL and PREDICT to TPU.
@@ -427,9 +485,8 @@ def wrapped_model_fn(model_fn):
       return _call_model_fn_without_tpu(
           model_fn, features, labels, mode, config, params)

-    # Now for TPU training.
-    if params is not None and _BATCH_SIZE_KEY in params:
-      params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
+    # Now for TPU training. `params` is never `None`.
+    params[_BATCH_SIZE_KEY] = _per_shard_batch_size(train_batch_size, config)

     assert isinstance(features, _PerShardOutput)
     features = features.as_list()
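
Because the constructor enforces exact divisibility, the value injected here always satisfies a simple invariant; a hypothetical sanity check, not part of the commit:

    assert (params[_BATCH_SIZE_KEY] * config.tpu_config.num_shards
            == train_batch_size)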