Adds train_batch_size into TPUEstimator.
PiperOrigin-RevId: 161102679
This commit is contained in:
parent
ad0ba9e6ea
commit
2d50faa5d5
@ -43,6 +43,7 @@ from tensorflow.python.training import training
|
||||
|
||||
|
||||
_BATCH_SIZE_KEY = 'batch_size'
|
||||
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
|
||||
|
||||
|
||||
def _tpu_job(run_config):
|
||||
@ -51,6 +52,11 @@ def _tpu_job(run_config):
|
||||
return None if run_config.master in ['', 'local'] else 'tpu_worker'
|
||||
|
||||
|
||||
def _per_shard_batch_size(global_batch_size, run_config):
|
||||
"""Returns the batch size for each shard."""
|
||||
return global_batch_size // run_config.tpu_config.num_shards
|
||||
|
||||
|
||||
class _SIGNAL(object):
|
||||
"""Signal used to control the input thread of infeed."""
|
||||
NEXT_BATCH = 1
|
||||
@ -168,8 +174,11 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
replicating inputs and models for each core, and returning to host
|
||||
periodically to run hooks.
|
||||
|
||||
Note: TpuEstimator transforms a global batch size in params to a per-shard
|
||||
batch size when calling the input_fn.
|
||||
Note: For training (evaluate and predict support on TPU are not yet
|
||||
implemented), TpuEstimator transforms a global batch size in params to a
|
||||
per-shard batch size when calling the `input_fn` and `model_fn`. Users should
|
||||
specify `train_batch_size` in constructor, and then get the batch size for
|
||||
each shard in `input_fn` and `model_fn` by `params['batch_size']`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -177,23 +186,59 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
model_dir=None,
|
||||
config=None,
|
||||
params=None,
|
||||
use_tpu=True):
|
||||
use_tpu=True,
|
||||
train_batch_size=None):
|
||||
"""Constructs an `TpuEstimator` instance.
|
||||
|
||||
Args:
|
||||
model_fn: Model function as required by `Estimator`. For training, the
|
||||
returned `EstimatorSpec` cannot have hooks as it is not supported in
|
||||
`TPUEstimator`.
|
||||
model_dir: Directory to save model parameters, graph and etc. This can
|
||||
also be used to load checkpoints from the directory into a estimator to
|
||||
continue training a previously saved model. If `None`, the model_dir in
|
||||
`config` will be used if set. If both are set, they must be same. If
|
||||
both are `None`, a temporary directory will be used.
|
||||
config: An `tpu_config.RunConfig` configuration object. Cannot be `None`.
|
||||
params: An optional `dict` of hyper parameters that will be passed into
|
||||
`input_fn` and `model_fn`. Keys are names of parameters, values are
|
||||
basic python types. There are reserved keys for `TPUEstimator`,
|
||||
including 'batch_size'.
|
||||
use_tpu: A bool indicating whether TPU support is enabled. Currently, only
|
||||
applied to training. Evaluate and predict still happen on CPU.
|
||||
train_batch_size: An int representing the global training batch size.
|
||||
TpuEstimator transforms this global batch size to a per-shard batch
|
||||
size, as params['batch_size'], when calling `input_fn` and `model_fn`.
|
||||
Cannot be `None` if `use_tpu` is `True`. Must be divisible by
|
||||
`config.tpu_config.num_shards`.
|
||||
|
||||
Raises:
|
||||
ValueError: `params` has reserved keys already.
|
||||
"""
|
||||
if config is None or not isinstance(config, tpu_config.RunConfig):
|
||||
raise ValueError(
|
||||
'`config` must be provided with type `tpu_config.RunConfig`')
|
||||
|
||||
if use_tpu and params is not None and _BATCH_SIZE_KEY in params:
|
||||
if not isinstance(params[_BATCH_SIZE_KEY], int):
|
||||
if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
|
||||
raise ValueError(
|
||||
'`{}` in params must be an int'.format(_BATCH_SIZE_KEY))
|
||||
params = copy.deepcopy(params)
|
||||
'{} are reserved keys but existed in params {}.'.format(
|
||||
_RESERVED_PARAMS_KEYS, params))
|
||||
|
||||
if use_tpu:
|
||||
if train_batch_size is None:
|
||||
raise ValueError('`train_batch_size` cannot be `None`')
|
||||
if not isinstance(train_batch_size, int):
|
||||
raise ValueError('`train_batch_size` must be an int')
|
||||
if train_batch_size < 1:
|
||||
raise ValueError('`train_batch_size` must be positive')
|
||||
|
||||
# The specified batch size is the batch size for the entire computation.
|
||||
# The input_fn is called per-shard, so we want to calculate the per-shard
|
||||
# batch size and pass that.
|
||||
if params[_BATCH_SIZE_KEY] % config.tpu_config.num_shards != 0:
|
||||
# The input_fn and model_fn are called per-shard, so we want to calculate
|
||||
# the per-shard batch size and pass that.
|
||||
if train_batch_size % config.tpu_config.num_shards != 0:
|
||||
raise ValueError(
|
||||
'batch size {} must be divisible by number of shards {}'
|
||||
.format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards))
|
||||
.format(train_batch_size, config.tpu_config.num_shards))
|
||||
|
||||
if use_tpu:
|
||||
# Verifies the model_fn signature according to Estimator framework.
|
||||
@ -201,7 +246,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
# We cannot store config and params in this constructor as parent
|
||||
# constructor might change them, such as assigning a temp dir for
|
||||
# config.model_dir.
|
||||
model_function = wrapped_model_fn(model_fn)
|
||||
model_function = wrapped_model_fn(model_fn, train_batch_size)
|
||||
else:
|
||||
model_function = model_fn
|
||||
|
||||
@ -210,7 +255,8 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
params=params)
|
||||
self.use_tpu = use_tpu
|
||||
self._use_tpu = use_tpu
|
||||
self._train_batch_size = train_batch_size
|
||||
|
||||
def _create_global_step(self, graph):
|
||||
"""Creates a global step suitable for TPUs.
|
||||
@ -252,9 +298,9 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
labels - `Tensor` or dictionary of `Tensor` with labels.
|
||||
|
||||
Raises:
|
||||
ValueError: if input_fn takes invalid arguments.
|
||||
ValueError: if input_fn takes invalid arguments or does not have `params`.
|
||||
"""
|
||||
if not self.use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
||||
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
||||
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
|
||||
|
||||
input_fn_args = estimator_lib._fn_args(input_fn) # pylint: disable=protected-access
|
||||
@ -262,12 +308,16 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
kwargs = {}
|
||||
if 'params' in input_fn_args:
|
||||
kwargs['params'] = self.params # a deep copy.
|
||||
else:
|
||||
raise ValueError('input_fn ({}) does not include params argument, '
|
||||
'required by TPUEstimator to pass batch size as '
|
||||
'params["batch_size"]'.format(input_fn))
|
||||
if 'config' in input_fn_args:
|
||||
kwargs['config'] = config
|
||||
|
||||
# Now for TPU training.
|
||||
if 'params' in kwargs and _BATCH_SIZE_KEY in kwargs['params']:
|
||||
kwargs['params'][_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
|
||||
per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
|
||||
kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size
|
||||
|
||||
job = _tpu_job(config)
|
||||
def placement_function(index):
|
||||
@ -287,8 +337,10 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
labels.append(result[1])
|
||||
else:
|
||||
features.append(result)
|
||||
|
||||
if not labels or all(l is None for l in labels):
|
||||
return _PerShardOutput(features), None
|
||||
|
||||
return _PerShardOutput(features), _PerShardOutput(labels)
|
||||
|
||||
|
||||
@ -302,16 +354,22 @@ def _verify_estimator_spec(estimator_spec):
|
||||
return estimator_spec
|
||||
|
||||
|
||||
def _call_model_fn(model_fn, features, labels, mode, config, params):
|
||||
def _call_model_fn(model_fn, features, labels, mode, config, params,
|
||||
require_params=False):
|
||||
"""Calls the model_fn with required parameters."""
|
||||
model_fn_args = estimator_lib._fn_args(model_fn) # pylint: disable=protected-access
|
||||
kwargs = {}
|
||||
if 'mode' in model_fn_args:
|
||||
kwargs['mode'] = mode
|
||||
if 'params' in model_fn_args:
|
||||
kwargs['params'] = params
|
||||
if 'config' in model_fn_args:
|
||||
kwargs['config'] = config
|
||||
if 'params' in model_fn_args:
|
||||
kwargs['params'] = params
|
||||
elif require_params:
|
||||
raise ValueError(
|
||||
'model_fn ({}) does not include params argument, '
|
||||
'required by TPUEstimator to pass batch size as '
|
||||
'params[\'batch_size\']'.format(model_fn))
|
||||
return model_fn(features=features, labels=labels, **kwargs)
|
||||
|
||||
|
||||
@ -321,7 +379,7 @@ def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
|
||||
config = copy.deepcopy(config)
|
||||
params = copy.deepcopy(params)
|
||||
return _verify_estimator_spec(_call_model_fn(
|
||||
model_fn, features, labels, mode, config, params))
|
||||
model_fn, features, labels, mode, config, params, require_params=True))
|
||||
|
||||
|
||||
def _call_model_fn_without_tpu(
|
||||
@ -416,10 +474,10 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
|
||||
return (dequeue_fn, enqueue_fn)
|
||||
|
||||
|
||||
def wrapped_model_fn(model_fn):
|
||||
def wrapped_model_fn(model_fn, train_batch_size):
|
||||
"""Returns a new model_fn, which wraps the TPU support."""
|
||||
|
||||
def _model_fn(features, labels, mode, config, params=None):
|
||||
def _model_fn(features, labels, mode, config, params):
|
||||
"""model_fn."""
|
||||
|
||||
# TODO(jhseu): Move to EVAL and PREDICT to TPU.
|
||||
@ -427,9 +485,8 @@ def wrapped_model_fn(model_fn):
|
||||
return _call_model_fn_without_tpu(
|
||||
model_fn, features, labels, mode, config, params)
|
||||
|
||||
# Now for TPU training.
|
||||
if params is not None and _BATCH_SIZE_KEY in params:
|
||||
params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
|
||||
# Now for TPU training. `params` is never `None`.
|
||||
params[_BATCH_SIZE_KEY] = _per_shard_batch_size(train_batch_size, config)
|
||||
|
||||
assert isinstance(features, _PerShardOutput)
|
||||
features = features.as_list()
|
||||
|
Loading…
Reference in New Issue
Block a user