Rename TpuEstimator to TPUEstimator and TpuConfig to TPUConfig to follow PEP8
naming conventions. PiperOrigin-RevId: 161704561
This commit is contained in:
parent
c9d03a568a
commit
4f54336348
@ -24,12 +24,12 @@ import collections
|
|||||||
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
|
||||||
|
|
||||||
|
|
||||||
class TpuConfig(collections.namedtuple(
|
class TPUConfig(collections.namedtuple(
|
||||||
'TpuConfig', ['iterations_per_loop', 'num_shards'])):
|
'TPUConfig', ['iterations_per_loop', 'num_shards'])):
|
||||||
"""TPU related configuration required by `TPUEstimator`."""
|
"""TPU related configuration required by `TPUEstimator`."""
|
||||||
|
|
||||||
def __new__(cls, iterations_per_loop=2, num_shards=2):
|
def __new__(cls, iterations_per_loop=2, num_shards=2):
|
||||||
return super(TpuConfig, cls).__new__(
|
return super(TPUConfig, cls).__new__(
|
||||||
cls,
|
cls,
|
||||||
iterations_per_loop=iterations_per_loop,
|
iterations_per_loop=iterations_per_loop,
|
||||||
num_shards=num_shards)
|
num_shards=num_shards)
|
||||||
@ -40,7 +40,7 @@ class RunConfig(run_config_lib.RunConfig):
|
|||||||
|
|
||||||
def __init__(self, tpu_config=None, **kwargs):
|
def __init__(self, tpu_config=None, **kwargs):
|
||||||
super(RunConfig, self).__init__(**kwargs)
|
super(RunConfig, self).__init__(**kwargs)
|
||||||
self._tpu_config = tpu_config or TpuConfig()
|
self._tpu_config = tpu_config or TPUConfig()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tpu_config(self):
|
def tpu_config(self):
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ===================================================================
|
# ===================================================================
|
||||||
|
|
||||||
"""TpuEstimator class."""
|
"""TPUEstimator class."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -71,15 +71,15 @@ class InfeedThreadController(object):
|
|||||||
stop. It could be the case that the `max_steps` is reached or some hook
|
stop. It could be the case that the `max_steps` is reached or some hook
|
||||||
requests the stop in the monitored_session.
|
requests the stop in the monitored_session.
|
||||||
|
|
||||||
This controller (with coordination with `TpuInfeedSessionHook`) does the
|
This controller (with coordination with `TPUInfeedSessionHook`) does the
|
||||||
following:
|
following:
|
||||||
|
|
||||||
1) It pre-infeeds one `batch` data for current TPU iterations.
|
1) It pre-infeeds one `batch` data for current TPU iterations.
|
||||||
|
|
||||||
2) When `before_run` of `TpuInfeedSessionHook` is called, one more `batch`
|
2) When `before_run` of `TPUInfeedSessionHook` is called, one more `batch`
|
||||||
data will be infed.
|
data will be infed.
|
||||||
|
|
||||||
3) When `end` of `TpuInfeedSessionHook` is called, the thread will end
|
3) When `end` of `TPUInfeedSessionHook` is called, the thread will end
|
||||||
gracefully.
|
gracefully.
|
||||||
|
|
||||||
So, we might need to adjust the algorithm here if the IO is slower than the
|
So, we might need to adjust the algorithm here if the IO is slower than the
|
||||||
@ -115,7 +115,7 @@ class InfeedThreadController(object):
|
|||||||
self._input_thd.join()
|
self._input_thd.join()
|
||||||
|
|
||||||
|
|
||||||
class TpuInfeedSessionHook(session_run_hook.SessionRunHook):
|
class TPUInfeedSessionHook(session_run_hook.SessionRunHook):
|
||||||
"""A Session hook setting up the TPU initialization and infeed.
|
"""A Session hook setting up the TPU initialization and infeed.
|
||||||
|
|
||||||
This hook does two major things:
|
This hook does two major things:
|
||||||
@ -168,15 +168,15 @@ class _PerShardOutput(object):
|
|||||||
return self.output
|
return self.output
|
||||||
|
|
||||||
|
|
||||||
class TpuEstimator(estimator_lib.Estimator):
|
class TPUEstimator(estimator_lib.Estimator):
|
||||||
"""Estimator with TPU support.
|
"""Estimator with TPU support.
|
||||||
|
|
||||||
TpuEstimator handles many of the details of running on TPU devices, such as
|
TPUEstimator handles many of the details of running on TPU devices, such as
|
||||||
replicating inputs and models for each core, and returning to host
|
replicating inputs and models for each core, and returning to host
|
||||||
periodically to run hooks.
|
periodically to run hooks.
|
||||||
|
|
||||||
Note: For training (evaluate and predict support on TPU are not yet
|
Note: For training (evaluate and predict support on TPU are not yet
|
||||||
implemented), TpuEstimator transforms a global batch size in params to a
|
implemented), TPUEstimator transforms a global batch size in params to a
|
||||||
per-shard batch size when calling the `input_fn` and `model_fn`. Users should
|
per-shard batch size when calling the `input_fn` and `model_fn`. Users should
|
||||||
specify `train_batch_size` in constructor, and then get the batch size for
|
specify `train_batch_size` in constructor, and then get the batch size for
|
||||||
each shard in `input_fn` and `model_fn` by `params['batch_size']`.
|
each shard in `input_fn` and `model_fn` by `params['batch_size']`.
|
||||||
@ -189,7 +189,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
|||||||
params=None,
|
params=None,
|
||||||
use_tpu=True,
|
use_tpu=True,
|
||||||
train_batch_size=None):
|
train_batch_size=None):
|
||||||
"""Constructs an `TpuEstimator` instance.
|
"""Constructs an `TPUEstimator` instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_fn: Model function as required by `Estimator`. For training, the
|
model_fn: Model function as required by `Estimator`. For training, the
|
||||||
@ -208,7 +208,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
|||||||
use_tpu: A bool indicating whether TPU support is enabled. Currently, only
|
use_tpu: A bool indicating whether TPU support is enabled. Currently, only
|
||||||
applied to training. Evaluate and predict still happen on CPU.
|
applied to training. Evaluate and predict still happen on CPU.
|
||||||
train_batch_size: An int representing the global training batch size.
|
train_batch_size: An int representing the global training batch size.
|
||||||
TpuEstimator transforms this global batch size to a per-shard batch
|
TPUEstimator transforms this global batch size to a per-shard batch
|
||||||
size, as params['batch_size'], when calling `input_fn` and `model_fn`.
|
size, as params['batch_size'], when calling `input_fn` and `model_fn`.
|
||||||
Cannot be `None` if `use_tpu` is `True`. Must be divisible by
|
Cannot be `None` if `use_tpu` is `True`. Must be divisible by
|
||||||
`config.tpu_config.num_shards`.
|
`config.tpu_config.num_shards`.
|
||||||
@ -251,7 +251,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
|||||||
else:
|
else:
|
||||||
model_function = model_fn
|
model_function = model_fn
|
||||||
|
|
||||||
super(TpuEstimator, self).__init__(
|
super(TPUEstimator, self).__init__(
|
||||||
model_fn=model_function,
|
model_fn=model_function,
|
||||||
model_dir=model_dir,
|
model_dir=model_dir,
|
||||||
config=config,
|
config=config,
|
||||||
@ -302,7 +302,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
|||||||
ValueError: if input_fn takes invalid arguments or does not have `params`.
|
ValueError: if input_fn takes invalid arguments or does not have `params`.
|
||||||
"""
|
"""
|
||||||
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
||||||
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
|
return super(TPUEstimator, self)._call_input_fn(input_fn, mode)
|
||||||
|
|
||||||
input_fn_args = util.fn_args(input_fn)
|
input_fn_args = util.fn_args(input_fn)
|
||||||
config = self.config # a deep copy.
|
config = self.config # a deep copy.
|
||||||
@ -518,7 +518,7 @@ def wrapped_model_fn(model_fn, train_batch_size):
|
|||||||
]
|
]
|
||||||
|
|
||||||
hooks = [
|
hooks = [
|
||||||
TpuInfeedSessionHook(config, enqueue_fn),
|
TPUInfeedSessionHook(config, enqueue_fn),
|
||||||
training.LoggingTensorHook(
|
training.LoggingTensorHook(
|
||||||
{'loss': array_ops.identity(loss),
|
{'loss': array_ops.identity(loss),
|
||||||
'step': training.get_global_step()},
|
'step': training.get_global_step()},
|
||||||
|
Loading…
Reference in New Issue
Block a user