Adds train_batch_size into TPUEstimator.
PiperOrigin-RevId: 161102679
parent ad0ba9e6ea · commit 2d50faa5d5
@@ -43,6 +43,7 @@ from tensorflow.python.training import training

 _BATCH_SIZE_KEY = 'batch_size'
+_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]


 def _tpu_job(run_config):
@@ -51,6 +52,11 @@ def _tpu_job(run_config):
   return None if run_config.master in ['', 'local'] else 'tpu_worker'


+def _per_shard_batch_size(global_batch_size, run_config):
+  """Returns the batch size for each shard."""
+  return global_batch_size // run_config.tpu_config.num_shards
+
+
 class _SIGNAL(object):
   """Signal used to control the input thread of infeed."""
   NEXT_BATCH = 1
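As a sanity check on the new helper's arithmetic (the constructor change below guarantees divisibility): a global batch of 1024 over 8 shards yields 128 per shard. A minimal, self-contained sketch; the `_Fake*` config classes are hypothetical stand-ins for `tpu_config.RunConfig`, not part of this change:

    # Hypothetical stand-ins for tpu_config.RunConfig / TPUConfig, only
    # so the helper can be exercised outside TensorFlow.
    class _FakeTpuConfig(object):
      num_shards = 8

    class _FakeRunConfig(object):
      tpu_config = _FakeTpuConfig()

    def _per_shard_batch_size(global_batch_size, run_config):
      # Same body as the helper added above.
      return global_batch_size // run_config.tpu_config.num_shards

    assert _per_shard_batch_size(1024, _FakeRunConfig()) == 128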
@@ -168,8 +174,11 @@ class TpuEstimator(estimator_lib.Estimator):
   replicating inputs and models for each core, and returning to host
   periodically to run hooks.

-  Note: TpuEstimator transforms a global batch size in params to a per-shard
-  batch size when calling the input_fn.
+  Note: For training (evaluate and predict support on TPU are not yet
+  implemented), TpuEstimator transforms a global batch size in params to a
+  per-shard batch size when calling the `input_fn` and `model_fn`. Users should
+  specify `train_batch_size` in the constructor, and then get the batch size
+  for each shard in `input_fn` and `model_fn` via `params['batch_size']`.
   """

   def __init__(self,
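A hedged sketch of what that contract looks like from the user side; the feature shapes and the `tf.random_uniform` pipeline here are placeholders, not part of this change:

    import tensorflow as tf

    def input_fn(params):
      # params['batch_size'] is the per-shard batch size derived from
      # train_batch_size; TpuEstimator injects it, the user must not.
      batch_size = params['batch_size']
      features = tf.random_uniform([batch_size, 10])
      labels = tf.random_uniform([batch_size], maxval=10, dtype=tf.int32)
      return features, labels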
@@ -177,23 +186,59 @@ class TpuEstimator(estimator_lib.Estimator):
                model_dir=None,
                config=None,
                params=None,
-               use_tpu=True):
+               use_tpu=True,
+               train_batch_size=None):
+    """Constructs a `TpuEstimator` instance.
+
+    Args:
+      model_fn: Model function as required by `Estimator`. For training, the
+        returned `EstimatorSpec` cannot have hooks, as they are not supported
+        in `TPUEstimator`.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model. If `None`, the
+        model_dir in `config` will be used if set. If both are set, they must
+        be the same. If both are `None`, a temporary directory will be used.
+      config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
+      params: An optional `dict` of hyperparameters that will be passed into
+        `input_fn` and `model_fn`. Keys are names of parameters, values are
+        basic Python types. There are reserved keys for `TPUEstimator`,
+        including 'batch_size'.
+      use_tpu: A bool indicating whether TPU support is enabled. Currently
+        only applied to training; evaluate and predict still happen on CPU.
+      train_batch_size: An int representing the global training batch size.
+        TpuEstimator transforms this global batch size to a per-shard batch
+        size, as params['batch_size'], when calling `input_fn` and `model_fn`.
+        Cannot be `None` if `use_tpu` is `True`. Must be divisible by
+        `config.tpu_config.num_shards`.
+
+    Raises:
+      ValueError: if `params` already contains a reserved key.
+    """
     if config is None or not isinstance(config, tpu_config.RunConfig):
       raise ValueError(
           '`config` must be provided with type `tpu_config.RunConfig`')

-    if use_tpu and params is not None and _BATCH_SIZE_KEY in params:
-      if not isinstance(params[_BATCH_SIZE_KEY], int):
-        raise ValueError(
-            '`{}` in params must be an int'.format(_BATCH_SIZE_KEY))
-      params = copy.deepcopy(params)
+    if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
+      raise ValueError(
+          '{} are reserved keys but existed in params {}.'.format(
+              _RESERVED_PARAMS_KEYS, params))
+
+    if use_tpu:
+      if train_batch_size is None:
+        raise ValueError('`train_batch_size` cannot be `None`')
+      if not isinstance(train_batch_size, int):
+        raise ValueError('`train_batch_size` must be an int')
+      if train_batch_size < 1:
+        raise ValueError('`train_batch_size` must be positive')
+
       # The specified batch size is the batch size for the entire computation.
-      # The input_fn is called per-shard, so we want to calculate the per-shard
-      # batch size and pass that.
-      if params[_BATCH_SIZE_KEY] % config.tpu_config.num_shards != 0:
+      # The input_fn and model_fn are called per-shard, so we want to calculate
+      # the per-shard batch size and pass that.
+      if train_batch_size % config.tpu_config.num_shards != 0:
         raise ValueError(
             'batch size {} must be divisible by number of shards {}'
-            .format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards))
+            .format(train_batch_size, config.tpu_config.num_shards))

     if use_tpu:
       # Verifies the model_fn signature according to Estimator framework.
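Taken together, construction now looks roughly like this. A sketch under assumptions: `my_model_fn` is a placeholder, and the exact `RunConfig`/`TPUConfig` fields are assumed rather than taken from this diff:

    config = tpu_config.RunConfig(
        tpu_config=tpu_config.TPUConfig(num_shards=8))  # assumed fields

    estimator = TpuEstimator(
        model_fn=my_model_fn,    # must declare a `params` argument
        config=config,
        use_tpu=True,
        train_batch_size=1024)   # each shard sees 1024 // 8 == 128

    # Both of these now raise ValueError:
    #   use_tpu=True without train_batch_size, and
    #   params={'batch_size': ...}, since 'batch_size' is reserved.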
@@ -201,7 +246,7 @@ class TpuEstimator(estimator_lib.Estimator):
       # We cannot store config and params in this constructor as parent
       # constructor might change them, such as assigning a temp dir for
       # config.model_dir.
-      model_function = wrapped_model_fn(model_fn)
+      model_function = wrapped_model_fn(model_fn, train_batch_size)
     else:
       model_function = model_fn
@@ -210,7 +255,8 @@ class TpuEstimator(estimator_lib.Estimator):
         model_dir=model_dir,
         config=config,
         params=params)
-    self.use_tpu = use_tpu
+    self._use_tpu = use_tpu
+    self._train_batch_size = train_batch_size

   def _create_global_step(self, graph):
     """Creates a global step suitable for TPUs.
@@ -252,9 +298,9 @@ class TpuEstimator(estimator_lib.Estimator):
       labels - `Tensor` or dictionary of `Tensor` with labels.

     Raises:
-      ValueError: if input_fn takes invalid arguments.
+      ValueError: if input_fn takes invalid arguments or does not have `params`.
     """
-    if not self.use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
+    if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
       return super(TpuEstimator, self)._call_input_fn(input_fn, mode)

     input_fn_args = estimator_lib._fn_args(input_fn)  # pylint: disable=protected-access
@@ -262,12 +308,16 @@ class TpuEstimator(estimator_lib.Estimator):
     kwargs = {}
     if 'params' in input_fn_args:
       kwargs['params'] = self.params  # a deep copy.
+    else:
+      raise ValueError('input_fn ({}) does not include params argument, '
+                       'required by TPUEstimator to pass batch size as '
+                       'params["batch_size"]'.format(input_fn))
     if 'config' in input_fn_args:
       kwargs['config'] = config

     # Now for TPU training.
-    if 'params' in kwargs and _BATCH_SIZE_KEY in kwargs['params']:
-      kwargs['params'][_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
+    per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
+    kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size

     job = _tpu_job(config)
     def placement_function(index):
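The new `else` branch relies on signature reflection, the same mechanism `estimator_lib._fn_args` wraps. A minimal illustration of the pattern using plain `inspect` (a stand-in for the internal helper, not the actual implementation):

    import inspect

    def _has_params_arg(fn):
      # Rough equivalent of the estimator_lib._fn_args membership check.
      return 'params' in inspect.getargspec(fn).args

    def good_input_fn(params):
      return None

    def bad_input_fn():  # would hit the new ValueError path above
      return None

    assert _has_params_arg(good_input_fn)
    assert not _has_params_arg(bad_input_fn)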
@@ -287,8 +337,10 @@ class TpuEstimator(estimator_lib.Estimator):
         labels.append(result[1])
       else:
         features.append(result)

     if not labels or all(l is None for l in labels):
       return _PerShardOutput(features), None

     return _PerShardOutput(features), _PerShardOutput(labels)
+
+
@@ -302,16 +354,22 @@ def _verify_estimator_spec(estimator_spec):
   return estimator_spec


-def _call_model_fn(model_fn, features, labels, mode, config, params):
+def _call_model_fn(model_fn, features, labels, mode, config, params,
+                   require_params=False):
   """Calls the model_fn with required parameters."""
   model_fn_args = estimator_lib._fn_args(model_fn)  # pylint: disable=protected-access
   kwargs = {}
   if 'mode' in model_fn_args:
     kwargs['mode'] = mode
-  if 'params' in model_fn_args:
-    kwargs['params'] = params
   if 'config' in model_fn_args:
     kwargs['config'] = config
+  if 'params' in model_fn_args:
+    kwargs['params'] = params
+  elif require_params:
+    raise ValueError(
+        'model_fn ({}) does not include params argument, '
+        'required by TPUEstimator to pass batch size as '
+        'params[\'batch_size\']'.format(model_fn))
   return model_fn(features=features, labels=labels, **kwargs)
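The kwarg plumbing above forwards only the optional arguments a model_fn actually declares. A self-contained sketch of that pattern, again with `inspect` standing in for the internal `_fn_args`:

    import inspect

    def call_with_declared_args(model_fn, features, labels, **optional):
      # Forward only the optional kwargs (mode/config/params) that the
      # model_fn declares -- the pattern _call_model_fn implements.
      declared = inspect.getargspec(model_fn).args
      kwargs = {k: v for k, v in optional.items() if k in declared}
      return model_fn(features=features, labels=labels, **kwargs)

    def tiny_model_fn(features, labels, params):
      return ('spec', params['batch_size'])

    # 'mode' is dropped because tiny_model_fn does not declare it.
    result = call_with_declared_args(
        tiny_model_fn, features=None, labels=None,
        mode='train', params={'batch_size': 128})
    assert result == ('spec', 128)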
@@ -321,7 +379,7 @@ def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
   config = copy.deepcopy(config)
   params = copy.deepcopy(params)
   return _verify_estimator_spec(_call_model_fn(
-      model_fn, features, labels, mode, config, params))
+      model_fn, features, labels, mode, config, params, require_params=True))


 def _call_model_fn_without_tpu(
@@ -416,10 +474,10 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
   return (dequeue_fn, enqueue_fn)


-def wrapped_model_fn(model_fn):
+def wrapped_model_fn(model_fn, train_batch_size):
   """Returns a new model_fn, which wraps the TPU support."""

-  def _model_fn(features, labels, mode, config, params=None):
+  def _model_fn(features, labels, mode, config, params):
     """model_fn."""

     # TODO(jhseu): Move EVAL and PREDICT to TPU.
@@ -427,9 +485,8 @@ def wrapped_model_fn(model_fn):
       return _call_model_fn_without_tpu(
           model_fn, features, labels, mode, config, params)

-    # Now for TPU training.
-    if params is not None and _BATCH_SIZE_KEY in params:
-      params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
+    # Now for TPU training. `params` is never `None`.
+    params[_BATCH_SIZE_KEY] = _per_shard_batch_size(train_batch_size, config)

     assert isinstance(features, _PerShardOutput)
     features = features.as_list()