Rename TpuEstimator to TPUEstimator and TpuConfig to TPUConfig to follow PEP8
naming conventions. PiperOrigin-RevId: 161704561
This commit is contained in:
parent
c9d03a568a
commit
4f54336348
@ -24,12 +24,12 @@ import collections
|
||||
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
||||
|
||||
|
||||
class TpuConfig(collections.namedtuple(
|
||||
'TpuConfig', ['iterations_per_loop', 'num_shards'])):
|
||||
class TPUConfig(collections.namedtuple(
|
||||
'TPUConfig', ['iterations_per_loop', 'num_shards'])):
|
||||
"""TPU related configuration required by `TPUEstimator`."""
|
||||
|
||||
def __new__(cls, iterations_per_loop=2, num_shards=2):
|
||||
return super(TpuConfig, cls).__new__(
|
||||
return super(TPUConfig, cls).__new__(
|
||||
cls,
|
||||
iterations_per_loop=iterations_per_loop,
|
||||
num_shards=num_shards)
|
||||
@ -40,7 +40,7 @@ class RunConfig(run_config_lib.RunConfig):
|
||||
|
||||
def __init__(self, tpu_config=None, **kwargs):
|
||||
super(RunConfig, self).__init__(**kwargs)
|
||||
self._tpu_config = tpu_config or TpuConfig()
|
||||
self._tpu_config = tpu_config or TPUConfig()
|
||||
|
||||
@property
|
||||
def tpu_config(self):
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""TpuEstimator class."""
|
||||
"""TPUEstimator class."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -71,15 +71,15 @@ class InfeedThreadController(object):
|
||||
stop. It could be the cases that the `max_steps` is reached or some hook
|
||||
requests the stop in the monitored_session.
|
||||
|
||||
This controller (with coordination with `TpuInfeedSessionHook`) does the
|
||||
This controller (with coordination with `TPUInfeedSessionHook`) does the
|
||||
following:
|
||||
|
||||
1) It pre-infeeds one `batch` data for current TPU iterations.
|
||||
|
||||
2) When `before_run` of `TpuInfeedSessionHook` is called, one more `batch`
|
||||
2) When `before_run` of `TPUInfeedSessionHook` is called, one more `batch`
|
||||
data will be infed.
|
||||
|
||||
3) When `end` of `TpuInfeedSessionHook` is called, the thread will end
|
||||
3) When `end` of `TPUInfeedSessionHook` is called, the thread will end
|
||||
gracefully.
|
||||
|
||||
So, we might need to adjust the algorithrm here if the IO is slower than the
|
||||
@ -115,7 +115,7 @@ class InfeedThreadController(object):
|
||||
self._input_thd.join()
|
||||
|
||||
|
||||
class TpuInfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
class TPUInfeedSessionHook(session_run_hook.SessionRunHook):
|
||||
"""A Session hook setting up the TPU initialization and infeed.
|
||||
|
||||
This hook does two major things:
|
||||
@ -168,15 +168,15 @@ class _PerShardOutput(object):
|
||||
return self.output
|
||||
|
||||
|
||||
class TpuEstimator(estimator_lib.Estimator):
|
||||
class TPUEstimator(estimator_lib.Estimator):
|
||||
"""Estimator with TPU support.
|
||||
|
||||
TpuEstimator handles many of the details of running on TPU devices, such as
|
||||
TPUEstimator handles many of the details of running on TPU devices, such as
|
||||
replicating inputs and models for each core, and returning to host
|
||||
periodically to run hooks.
|
||||
|
||||
Note: For training (evaluate and predict support on TPU are not yet
|
||||
implemented), TpuEstimator transforms a global batch size in params to a
|
||||
implemented), TPUEstimator transforms a global batch size in params to a
|
||||
per-shard batch size when calling the `input_fn` and `model_fn`. Users should
|
||||
specify `train_batch_size` in constructor, and then get the batch size for
|
||||
each shard in `input_fn` and `model_fn` by `params['batch_size']`.
|
||||
@ -189,7 +189,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
params=None,
|
||||
use_tpu=True,
|
||||
train_batch_size=None):
|
||||
"""Constructs an `TpuEstimator` instance.
|
||||
"""Constructs an `TPUEstimator` instance.
|
||||
|
||||
Args:
|
||||
model_fn: Model function as required by `Estimator`. For training, the
|
||||
@ -208,7 +208,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
use_tpu: A bool indicating whether TPU support is enabled. Currently, only
|
||||
applied to training. Evaluate and predict still happen on CPU.
|
||||
train_batch_size: An int representing the global training batch size.
|
||||
TpuEstimator transforms this global batch size to a per-shard batch
|
||||
TPUEstimator transforms this global batch size to a per-shard batch
|
||||
size, as params['batch_size'], when calling `input_fn` and `model_fn`.
|
||||
Cannot be `None` if `use_tpu` is `True`. Must be divisible by
|
||||
`config.tpu_config.num_shards`.
|
||||
@ -251,7 +251,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
else:
|
||||
model_function = model_fn
|
||||
|
||||
super(TpuEstimator, self).__init__(
|
||||
super(TPUEstimator, self).__init__(
|
||||
model_fn=model_function,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
@ -302,7 +302,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
ValueError: if input_fn takes invalid arguments or does not have `params`.
|
||||
"""
|
||||
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
||||
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
|
||||
return super(TPUEstimator, self)._call_input_fn(input_fn, mode)
|
||||
|
||||
input_fn_args = util.fn_args(input_fn)
|
||||
config = self.config # a deep copy.
|
||||
@ -518,7 +518,7 @@ def wrapped_model_fn(model_fn, train_batch_size):
|
||||
]
|
||||
|
||||
hooks = [
|
||||
TpuInfeedSessionHook(config, enqueue_fn),
|
||||
TPUInfeedSessionHook(config, enqueue_fn),
|
||||
training.LoggingTensorHook(
|
||||
{'loss': array_ops.identity(loss),
|
||||
'step': training.get_global_step()},
|
||||
|
Loading…
Reference in New Issue
Block a user