diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 168726a6b3e..b9da8dc35ab 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -43,6 +43,7 @@ from tensorflow.python.training import training
 
 
 _BATCH_SIZE_KEY = 'batch_size'
+_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
 
 
 def _tpu_job(run_config):
@@ -51,6 +52,11 @@ def _tpu_job(run_config):
   return None if run_config.master in ['', 'local'] else 'tpu_worker'
 
 
+def _per_shard_batch_size(global_batch_size, run_config):
+  """Returns the batch size for each shard."""
+  return global_batch_size // run_config.tpu_config.num_shards
+
+
 class _SIGNAL(object):
   """Signal used to control the input thread of infeed."""
   NEXT_BATCH = 1
@@ -168,8 +174,11 @@ class TpuEstimator(estimator_lib.Estimator):
   replicating inputs and models for each core, and returning to host
   periodically to run hooks.
 
-  Note: TpuEstimator transforms a global batch size in params to a per-shard
-  batch size when calling the input_fn.
+  Note: For training (evaluate and predict support on TPU are not yet
+  implemented), TpuEstimator transforms the global batch size to a per-shard
+  batch size when calling `input_fn` and `model_fn`. Users should specify
+  `train_batch_size` in the constructor, and then get the per-shard batch
+  size in `input_fn` and `model_fn` from `params['batch_size']`.
   """
 
   def __init__(self,
@@ -177,23 +186,59 @@ class TpuEstimator(estimator_lib.Estimator):
                model_dir=None,
                config=None,
                params=None,
-               use_tpu=True):
+               use_tpu=True,
+               train_batch_size=None):
+    """Constructs a `TpuEstimator` instance.
+
+    Args:
+      model_fn: Model function as required by `Estimator`. For training, the
+        returned `EstimatorSpec` cannot have hooks, as they are not supported
+        in `TPUEstimator`.
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model. If `None`, the
+        model_dir in `config` will be used if set. If both are set, they must
+        be the same. If both are `None`, a temporary directory will be used.
+      config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.
+      params: An optional `dict` of hyper parameters that will be passed into
+        `input_fn` and `model_fn`. Keys are names of parameters, values are
+        basic Python types. There are reserved keys for `TPUEstimator`,
+        including 'batch_size'.
+      use_tpu: A bool indicating whether TPU support is enabled. Currently,
+        this applies to training only; evaluate and predict still run on CPU.
+      train_batch_size: An int representing the global training batch size.
+        TpuEstimator transforms this global batch size to a per-shard batch
+        size, passed as params['batch_size'], when calling `input_fn` and
+        `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be
+        divisible by `config.tpu_config.num_shards`.
+
+    Raises:
+      ValueError: if `params` already contains any of the reserved keys.
+ """ if config is None or not isinstance(config, tpu_config.RunConfig): raise ValueError( '`config` must be provided with type `tpu_config.RunConfig`') - if use_tpu and params is not None and _BATCH_SIZE_KEY in params: - if not isinstance(params[_BATCH_SIZE_KEY], int): - raise ValueError( - '`{}` in params must be an int'.format(_BATCH_SIZE_KEY)) - params = copy.deepcopy(params) + if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): + raise ValueError( + '{} are reserved keys but existed in params {}.'.format( + _RESERVED_PARAMS_KEYS, params)) + + if use_tpu: + if train_batch_size is None: + raise ValueError('`train_batch_size` cannot be `None`') + if not isinstance(train_batch_size, int): + raise ValueError('`train_batch_size` must be an int') + if train_batch_size < 1: + raise ValueError('`train_batch_size` must be positive') + # The specified batch size is the batch size for the entire computation. - # The input_fn is called per-shard, so we want to calculate the per-shard - # batch size and pass that. - if params[_BATCH_SIZE_KEY] % config.tpu_config.num_shards != 0: + # The input_fn and model_fn are called per-shard, so we want to calculate + # the per-shard batch size and pass that. + if train_batch_size % config.tpu_config.num_shards != 0: raise ValueError( 'batch size {} must be divisible by number of shards {}' - .format(params[_BATCH_SIZE_KEY], config.tpu_config.num_shards)) + .format(train_batch_size, config.tpu_config.num_shards)) if use_tpu: # Verifies the model_fn signature according to Estimator framework. @@ -201,7 +246,7 @@ class TpuEstimator(estimator_lib.Estimator): # We cannot store config and params in this constructor as parent # constructor might change them, such as assigning a temp dir for # config.model_dir. - model_function = wrapped_model_fn(model_fn) + model_function = wrapped_model_fn(model_fn, train_batch_size) else: model_function = model_fn @@ -210,7 +255,8 @@ class TpuEstimator(estimator_lib.Estimator): model_dir=model_dir, config=config, params=params) - self.use_tpu = use_tpu + self._use_tpu = use_tpu + self._train_batch_size = train_batch_size def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -252,9 +298,9 @@ class TpuEstimator(estimator_lib.Estimator): labels - `Tensor` or dictionary of `Tensor` with labels. Raises: - ValueError: if input_fn takes invalid arguments. + ValueError: if input_fn takes invalid arguments or does not have `params`. """ - if not self.use_tpu or mode != model_fn_lib.ModeKeys.TRAIN: + if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN: return super(TpuEstimator, self)._call_input_fn(input_fn, mode) input_fn_args = estimator_lib._fn_args(input_fn) # pylint: disable=protected-access @@ -262,12 +308,16 @@ class TpuEstimator(estimator_lib.Estimator): kwargs = {} if 'params' in input_fn_args: kwargs['params'] = self.params # a deep copy. + else: + raise ValueError('input_fn ({}) does not include params argument, ' + 'required by TPUEstimator to pass batch size as ' + 'params["batch_size"]'.format(input_fn)) if 'config' in input_fn_args: kwargs['config'] = config # Now for TPU training. 
-    if 'params' in kwargs and _BATCH_SIZE_KEY in kwargs['params']:
-      kwargs['params'][_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
+    per_shard_batch_size = _per_shard_batch_size(self._train_batch_size, config)
+    kwargs['params'][_BATCH_SIZE_KEY] = per_shard_batch_size
 
     job = _tpu_job(config)
    def placement_function(index):
@@ -287,8 +337,10 @@
         labels.append(result[1])
       else:
         features.append(result)
+
     if not labels or all(l is None for l in labels):
       return _PerShardOutput(features), None
+
     return _PerShardOutput(features), _PerShardOutput(labels)
 
 
@@ -302,16 +354,22 @@ def _verify_estimator_spec(estimator_spec):
   return estimator_spec
 
 
-def _call_model_fn(model_fn, features, labels, mode, config, params):
+def _call_model_fn(model_fn, features, labels, mode, config, params,
+                   require_params=False):
   """Calls the model_fn with required parameters."""
   model_fn_args = estimator_lib._fn_args(model_fn)  # pylint: disable=protected-access
   kwargs = {}
   if 'mode' in model_fn_args:
     kwargs['mode'] = mode
-  if 'params' in model_fn_args:
-    kwargs['params'] = params
   if 'config' in model_fn_args:
     kwargs['config'] = config
+  if 'params' in model_fn_args:
+    kwargs['params'] = params
+  elif require_params:
+    raise ValueError(
+        'model_fn ({}) does not include params argument, '
+        'required by TPUEstimator to pass batch size as '
+        'params[\'batch_size\']'.format(model_fn))
   return model_fn(features=features, labels=labels, **kwargs)
 
 
@@ -321,7 +379,7 @@ def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
   config = copy.deepcopy(config)
   params = copy.deepcopy(params)
   return _verify_estimator_spec(_call_model_fn(
-      model_fn, features, labels, mode, config, params))
+      model_fn, features, labels, mode, config, params, require_params=True))
 
 
 def _call_model_fn_without_tpu(
@@ -416,10 +474,10 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels):
   return (dequeue_fn, enqueue_fn)
 
 
-def wrapped_model_fn(model_fn):
+def wrapped_model_fn(model_fn, train_batch_size):
   """Returns a new model_fn, which wraps the TPU support."""
 
-  def _model_fn(features, labels, mode, config, params=None):
+  def _model_fn(features, labels, mode, config, params):
    """model_fn."""
 
     # TODO(jhseu): Move to EVAL and PREDICT to TPU.
@@ -427,9 +485,8 @@ def wrapped_model_fn(model_fn):
       return _call_model_fn_without_tpu(
           model_fn, features, labels, mode, config, params)
 
-    # Now for TPU training.
-    if params is not None and _BATCH_SIZE_KEY in params:
-      params[_BATCH_SIZE_KEY] //= config.tpu_config.num_shards
+    # Now for TPU training. `params` is never `None`.
+    params[_BATCH_SIZE_KEY] = _per_shard_batch_size(train_batch_size, config)
 
     assert isinstance(features, _PerShardOutput)
     features = features.as_list()
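
Example usage after this change (a minimal sketch, not part of the diff: `my_input_fn`, `my_model_fn`, the random input pipeline, and the `RunConfig`/`TPUConfig` arguments are illustrative assumptions; a real TPU job also needs a master address and would typically wrap its optimizer in `tpu_optimizer.CrossShardOptimizer`). The constructor now takes the global `train_batch_size`, while `input_fn` and `model_fn` must both accept `params` and read the per-shard batch size from `params['batch_size']`:

import tensorflow as tf

from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator


def my_input_fn(params):
  # `input_fn` must declare `params`; TpuEstimator._call_input_fn injects
  # the per-shard batch size (train_batch_size // num_shards) here.
  batch_size = params['batch_size']
  features = tf.random_uniform([batch_size, 10])
  labels = tf.random_uniform([batch_size, 1])
  return features, labels


def my_model_fn(features, labels, mode, params):
  # `model_fn` must declare `params` too (require_params=True on the TPU
  # path); params['batch_size'] is the same per-shard value.
  del params  # unused in this sketch
  predictions = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels, predictions)
  optimizer = tf.train.GradientDescentOptimizer(0.01)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  # No hooks: the returned EstimatorSpec cannot have hooks on the TPU path.
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


config = tpu_config.RunConfig(
    tpu_config=tpu_config.TPUConfig(num_shards=8))  # assumed 8-shard setup

estimator = tpu_estimator.TpuEstimator(
    model_fn=my_model_fn,
    config=config,
    use_tpu=True,
    train_batch_size=1024)  # global batch; each shard sees 1024 // 8 == 128

estimator.train(input_fn=my_input_fn, max_steps=100)

Note that passing 'batch_size' through `params` now raises a ValueError, since 'batch_size' is in _RESERVED_PARAMS_KEYS, as does a `train_batch_size` that is not divisible by `num_shards`.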