Adds train_batch_size to TPUEstimator.

PiperOrigin-RevId: 161102679
Jianwei Xie 2017-07-06 11:05:53 -07:00 committed by TensorFlower Gardener
parent ad0ba9e6ea
commit 2d50faa5d5


@@ -43,6 +43,7 @@ from tensorflow.python.training import training
_BATCH_SIZE_KEY = 'batch_size'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
def _tpu_job(run_config):
@@ -51,6 +52,11 @@ def _tpu_job(run_config):
return None if run_config.master in ['', 'local'] else 'tpu_worker'
def _per_shard_batch_size(global_batch_size, run_config):
"""Returns the batch size for each shard."""
return global_batch_size // run_config.tpu_config.num_shards
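# For illustration: with run_config.tpu_config.num_shards == 8 (an assumed
# value), _per_shard_batch_size(1024, run_config) == 1024 // 8 == 128.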
class _SIGNAL(object):
"""Signal used to control the input thread of infeed."""
NEXT_BATCH = 1
@@ -168,8 +174,11 @@ class TpuEstimator(estimator_lib.Estimator):
replicating inputs and models for each core, and returning to host
periodically to run hooks.
Note: TpuEstimator transforms a global batch size in params to a per-shard
batch size when calling the input_fn.
Note: For training (evaluation and prediction on TPU are not yet
implemented), TpuEstimator transforms the global batch size in params to a
per-shard batch size when calling the `input_fn` and `model_fn`. Users should
specify `train_batch_size` in the constructor and read the per-shard batch
size in `input_fn` and `model_fn` via `params['batch_size']`.
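Example (a minimal sketch; `my_model_fn` and `my_run_config` are
hypothetical user code, assuming `config.tpu_config.num_shards == 8`):

  def my_input_fn(params):
    # params['batch_size'] is the per-shard batch size: 1024 // 8 == 128.
    features = tf.random_uniform([params['batch_size'], 10])
    labels = tf.random_uniform([params['batch_size']], maxval=2,
                               dtype=tf.int32)
    return features, labels

  estimator = TpuEstimator(
      model_fn=my_model_fn,    # must also accept `params`
      config=my_run_config,    # a tpu_config.RunConfig
      use_tpu=True,
      train_batch_size=1024)   # global batch size
  estimator.train(my_input_fn, max_steps=1000)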
"""
def __init__(self,
@@ -177,23 +186,59 @@ class TpuEstimator(estimator_lib.Estimator):
model_dir=None,
config=None,
params=None,
use_tpu=True):
use_tpu=True,
train_batch_size=None):
"""Constructs an `TpuEstimator` instance.
Args:
model_fn: Model function as required by `Estimator`. For training, the
returned `EstimatorSpec` cannot have hooks, as hooks are not supported
in `TPUEstimator`.
model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator to
continue training a previously saved model. If `None`, the model_dir in
`config` will be used if set. If both are set, they must be the same. If
both are `None`, a temporary directory will be used.
config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
params: An optional `dict` of hyperparameters that will be passed into
`input_fn` and `model_fn`. Keys are names of parameters, values are
basic Python types. There are reserved keys for `TPUEstimator`,
including 'batch_size'.
use_tpu: A bool indicating whether TPU support is enabled. Currently
only applies to training; evaluation and prediction still happen on CPU.
train_batch_size: An int representing the global training batch size.
TpuEstimator transforms this global batch size to a per-shard batch
size, exposed as params['batch_size'], when calling `input_fn` and
`model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be divisible
by `config.tpu_config.num_shards`.
Raises:
ValueError: if `params` contains reserved keys.
"""
if config is None or not isinstance(config, tpu_config.RunConfig):
raise ValueError(
'`config` must be provided with type `tpu_config.RunConfig`')
if use_tpu and params is not None and _BATCH_SIZE_KEY in params:
if not isinstance(params[_BATCH_SIZE_KEY], int):
raise ValueError(
'`{}` in params must be an int'.format(_BATCH_SIZE_KEY))
params = copy.deepcopy(params)
if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
raise ValueError(
'{} are reserved keys but exist in params {}.'.format(
_RESERVED_PARAMS_KEYS, params))
if use_tpu:
if train_batch_size is None:
raise ValueError('`train_batch_size` cannot be `None`')
if not isinstance(train_batch_size, int):
raise ValueError('`train_batch_size` must be an int')
if train_batch_size < 1:
raise ValueError('`train_batch_size` must be positive')
# The specified batch size is the batch size for the entire computation.
# The input_fn is called per-shard, so we want to calculate the per-shard
# batch size and pass that.
if params[_BATCH_SIZE_KEY] % config.tpu_config.num_shards != 0:
# The input_fn and model_fn are called per-shard, so we want to calculate
# the per-shard batch size and pass that.
if train_batch_size % config.tpu_config.num_shards != 0:
raise ValueError(
'batch size {} must be divisible by number of shards {}'
.format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards))
.format(train_batch_size, config.tpu_config.num_shards))
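# For illustration (assuming config.tpu_config.num_shards == 8):
# train_batch_size=1024 passes this check and yields a per-shard batch size
# of 1024 // 8 == 128, while train_batch_size=1000 raises ValueError
# because 1000 % 8 == 4.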
if use_tpu:
# Verifies the model_fn signature according to the Estimator framework.
@@ -201,7 +246,7 @@ class TpuEstimator(estimator_lib.Estimator):
# We cannot store config and params in this constructor as the parent
# constructor might change them, e.g. by assigning a temp dir for
# config.model_dir.
model_function = wrapped_model_fn(model_fn)
model_function = wrapped_model_fn(model_fn, train_batch_size)
else:
model_function = model_fn
@@ -210,7 +255,8 @@ class TpuEstimator(estimator_lib.Estimator):
model_dir=model_dir,
config=config,
params=params)
self.use_tpu = use_tpu
self._use_tpu = use_tpu
self._train_batch_size = train_batch_size
def _create_global_step(self, graph):
"""Creates a global step suitable for TPUs.
@@ -252,9 +298,9 @@ class TpuEstimator(estimator_lib.Estimator):
labels - `Tensor` or dictionary of `Tensor` with labels.
Raises:
ValueError: if input_fn takes invalid arguments.
ValueError: if input_fn takes invalid arguments or does not accept a
`params` argument.
"""
if not self.use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
input_fn_args = estimator_lib._fn_args(input_fn) # pylint: disable=protected-access
@@ -262,12 +308,16 @@ class TpuEstimator(estimator_lib.Estimator):
kwargs = {}
if 'params' in input_fn_args:
kwargs['params'] = self.params # a deep copy.
else:
raise ValueError('input_fn ({}) does not include params argument, '
'required by TPUEstimator to pass batch size as '
'params["batch_size"]'.format(input_fn))
if 'config' in input_fn_args:
kwargs['config'] = config
# Now for TPU training.
if 'params' in kwargs and _BATCH_SIZE_KEY in kwargs['params']:
kwargs['params'][_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size
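# E.g., with self._train_batch_size == 1024 and 8 shards (assumed values),
# each per-shard invocation of input_fn receives params['batch_size'] == 128.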
job = _tpu_job(config)
def placement_function(index):
@@ -287,8 +337,10 @@ class TpuEstimator(estimator_lib.Estimator):
labels.append(result[1])
else:
features.append(result)
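# With 8 shards, for example, `features` (and `labels`, when present) is a
# list of 8 per-shard results; _PerShardOutput wraps the list so downstream
# code can recover the per-shard values via .as_list().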
if not labels or all(l is None for l in labels):
return _PerShardOutput(features), None
return _PerShardOutput(features), _PerShardOutput(labels)
@@ -302,16 +354,22 @@ def _verify_estimator_spec(estimator_spec):
return estimator_spec
def _call_model_fn(model_fn, features, labels, mode, config, params):
def _call_model_fn(model_fn, features, labels, mode, config, params,
require_params=False):
"""Calls the model_fn with required parameters."""
model_fn_args = estimator_lib._fn_args(model_fn) # pylint: disable=protected-access
kwargs = {}
if 'mode' in model_fn_args:
kwargs['mode'] = mode
if 'params' in model_fn_args:
kwargs['params'] = params
if 'config' in model_fn_args:
kwargs['config'] = config
if 'params' in model_fn_args:
kwargs['params'] = params
elif require_params:
raise ValueError(
'model_fn ({}) does not include params argument, '
'required by TPUEstimator to pass batch size as '
'params[\'batch_size\']'.format(model_fn))
return model_fn(features=features, labels=labels, **kwargs)
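# `_fn_args` is an internal Estimator helper; a rough standard-library
# equivalent of the introspection used above (an assumption, not the actual
# implementation) is:
#
#   import inspect
#   def _fn_args(fn):
#     return tuple(inspect.signature(fn).parameters)
#
# 'params' in _fn_args(model_fn) then holds only when model_fn declares a
# `params` argument, which is what `require_params=True` enforces.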
@@ -321,7 +379,7 @@ def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
config = copy.deepcopy(config)
params = copy.deepcopy(params)
return _verify_estimator_spec(_call_model_fn(
model_fn, features, labels, mode, config, params))
model_fn, features, labels, mode, config, params, require_params=True))
def _call_model_fn_without_tpu(
@@ -416,10 +474,10 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
return (dequeue_fn, enqueue_fn)
def wrapped_model_fn(model_fn):
def wrapped_model_fn(model_fn, train_batch_size):
"""Returns a new model_fn, which wraps the TPU support."""
def _model_fn(features, labels, mode, config, params=None):
def _model_fn(features, labels, mode, config, params):
"""model_fn."""
# TODO(jhseu): Move EVAL and PREDICT to TPU.
@@ -427,9 +485,8 @@ def wrapped_model_fn(model_fn):
return _call_model_fn_without_tpu(
model_fn, features, labels, mode, config, params)
# Now for TPU training.
if params is not None and _BATCH_SIZE_KEY in params:
params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
# Now for TPU training. `params` is never `None`.
params[_BATCH_SIZE_KEY] = _per_shard_batch_size(train_batch_size, config)
assert isinstance(features, _PerShardOutput)
features = features.as_list()