Rename TpuEstimator to TPUEstimator and TpuConfig to TPUConfig to follow PEP8

naming conventions.

PiperOrigin-RevId: 161704561
This commit is contained in:
Jonathan Hseu 2017-07-12 12:59:50 -07:00 committed by TensorFlower Gardener
parent c9d03a568a
commit 4f54336348
2 changed files with 17 additions and 17 deletions

View File

@ -24,12 +24,12 @@ import collections
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
class TpuConfig(collections.namedtuple( class TPUConfig(collections.namedtuple(
'TpuConfig', ['iterations_per_loop', 'num_shards'])): 'TPUConfig', ['iterations_per_loop', 'num_shards'])):
"""TPU related configuration required by `TPUEstimator`.""" """TPU related configuration required by `TPUEstimator`."""
def __new__(cls, iterations_per_loop=2, num_shards=2): def __new__(cls, iterations_per_loop=2, num_shards=2):
return super(TpuConfig, cls).__new__( return super(TPUConfig, cls).__new__(
cls, cls,
iterations_per_loop=iterations_per_loop, iterations_per_loop=iterations_per_loop,
num_shards=num_shards) num_shards=num_shards)
@ -40,7 +40,7 @@ class RunConfig(run_config_lib.RunConfig):
def __init__(self, tpu_config=None, **kwargs): def __init__(self, tpu_config=None, **kwargs):
super(RunConfig, self).__init__(**kwargs) super(RunConfig, self).__init__(**kwargs)
self._tpu_config = tpu_config or TpuConfig() self._tpu_config = tpu_config or TPUConfig()
@property @property
def tpu_config(self): def tpu_config(self):

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# =================================================================== # ===================================================================
"""TpuEstimator class.""" """TPUEstimator class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -71,15 +71,15 @@ class InfeedThreadController(object):
stop. It could be the cases that the `max_steps` is reached or some hook stop. It could be the cases that the `max_steps` is reached or some hook
requests the stop in the monitored_session. requests the stop in the monitored_session.
This controller (with coordination with `TpuInfeedSessionHook`) does the This controller (with coordination with `TPUInfeedSessionHook`) does the
following: following:
1) It pre-infeeds one `batch` data for current TPU iterations. 1) It pre-infeeds one `batch` data for current TPU iterations.
2) When `before_run` of `TpuInfeedSessionHook` is called, one more `batch` 2) When `before_run` of `TPUInfeedSessionHook` is called, one more `batch`
data will be infed. data will be infed.
3) When `end` of `TpuInfeedSessionHook` is called, the thread will end 3) When `end` of `TPUInfeedSessionHook` is called, the thread will end
gracefully. gracefully.
So, we might need to adjust the algorithrm here if the IO is slower than the So, we might need to adjust the algorithrm here if the IO is slower than the
@ -115,7 +115,7 @@ class InfeedThreadController(object):
self._input_thd.join() self._input_thd.join()
class TpuInfeedSessionHook(session_run_hook.SessionRunHook): class TPUInfeedSessionHook(session_run_hook.SessionRunHook):
"""A Session hook setting up the TPU initialization and infeed. """A Session hook setting up the TPU initialization and infeed.
This hook does two major things: This hook does two major things:
@ -168,15 +168,15 @@ class _PerShardOutput(object):
return self.output return self.output
class TpuEstimator(estimator_lib.Estimator): class TPUEstimator(estimator_lib.Estimator):
"""Estimator with TPU support. """Estimator with TPU support.
TpuEstimator handles many of the details of running on TPU devices, such as TPUEstimator handles many of the details of running on TPU devices, such as
replicating inputs and models for each core, and returning to host replicating inputs and models for each core, and returning to host
periodically to run hooks. periodically to run hooks.
Note: For training (evaluate and predict support on TPU are not yet Note: For training (evaluate and predict support on TPU are not yet
implemented), TpuEstimator transforms a global batch size in params to a implemented), TPUEstimator transforms a global batch size in params to a
per-shard batch size when calling the `input_fn` and `model_fn`. Users should per-shard batch size when calling the `input_fn` and `model_fn`. Users should
specify `train_batch_size` in constructor, and then get the batch size for specify `train_batch_size` in constructor, and then get the batch size for
each shard in `input_fn` and `model_fn` by `params['batch_size']`. each shard in `input_fn` and `model_fn` by `params['batch_size']`.
@ -189,7 +189,7 @@ class TpuEstimator(estimator_lib.Estimator):
params=None, params=None,
use_tpu=True, use_tpu=True,
train_batch_size=None): train_batch_size=None):
"""Constructs an `TpuEstimator` instance. """Constructs an `TPUEstimator` instance.
Args: Args:
model_fn: Model function as required by `Estimator`. For training, the model_fn: Model function as required by `Estimator`. For training, the
@ -208,7 +208,7 @@ class TpuEstimator(estimator_lib.Estimator):
use_tpu: A bool indicating whether TPU support is enabled. Currently, only use_tpu: A bool indicating whether TPU support is enabled. Currently, only
applied to training. Evaluate and predict still happen on CPU. applied to training. Evaluate and predict still happen on CPU.
train_batch_size: An int representing the global training batch size. train_batch_size: An int representing the global training batch size.
TpuEstimator transforms this global batch size to a per-shard batch TPUEstimator transforms this global batch size to a per-shard batch
size, as params['batch_size'], when calling `input_fn` and `model_fn`. size, as params['batch_size'], when calling `input_fn` and `model_fn`.
Cannot be `None` if `use_tpu` is `True`. Must be divisible by Cannot be `None` if `use_tpu` is `True`. Must be divisible by
`config.tpu_config.num_shards`. `config.tpu_config.num_shards`.
@ -251,7 +251,7 @@ class TpuEstimator(estimator_lib.Estimator):
else: else:
model_function = model_fn model_function = model_fn
super(TpuEstimator, self).__init__( super(TPUEstimator, self).__init__(
model_fn=model_function, model_fn=model_function,
model_dir=model_dir, model_dir=model_dir,
config=config, config=config,
@ -302,7 +302,7 @@ class TpuEstimator(estimator_lib.Estimator):
ValueError: if input_fn takes invalid arguments or does not have `params`. ValueError: if input_fn takes invalid arguments or does not have `params`.
""" """
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN: if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
return super(TpuEstimator, self)._call_input_fn(input_fn, mode) return super(TPUEstimator, self)._call_input_fn(input_fn, mode)
input_fn_args = util.fn_args(input_fn) input_fn_args = util.fn_args(input_fn)
config = self.config # a deep copy. config = self.config # a deep copy.
@ -518,7 +518,7 @@ def wrapped_model_fn(model_fn, train_batch_size):
] ]
hooks = [ hooks = [
TpuInfeedSessionHook(config, enqueue_fn), TPUInfeedSessionHook(config, enqueue_fn),
training.LoggingTensorHook( training.LoggingTensorHook(
{'loss': array_ops.identity(loss), {'loss': array_ops.identity(loss),
'step': training.get_global_step()}, 'step': training.get_global_step()},