Delete the TF core copy of TPUEstimator in favor of the Estimator repo's version.

PiperOrigin-RevId: 247493942
Jonathan Hseu 2019-05-09 14:13:20 -07:00 committed by TensorFlower Gardener
parent c182167d3f
commit 72d36073e9
15 changed files with 67 additions and 6182 deletions
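
The diffs below show the mechanism: the TensorFlow-core TPU estimator modules become thin stubs that wildcard-import from the `tensorflow_estimator` package, so existing import paths keep resolving to the moved implementation. A minimal sketch of what that means for user code (module paths are taken from the diffs; the identity check simply reflects how a wildcard re-export behaves):

```
# After this change the TPU Estimator classes live in tensorflow_estimator;
# the old tensorflow.python.tpu modules only re-export them, so both import
# paths resolve to the same objects.
from tensorflow.python.tpu import tpu_config as tf_core_tpu_config
from tensorflow_estimator.python.estimator.tpu import tpu_config as estimator_tpu_config

assert tf_core_tpu_config.TPUConfig is estimator_tpu_config.TPUConfig
```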


@@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import *
from tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import *
# pylint: enable=wildcard-import,unused-import


@@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.error_handling import *
from tensorflow_estimator.python.estimator.tpu.error_handling import *
# pylint: enable=wildcard-import,unused-import


@@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_config import *
from tensorflow_estimator.python.estimator.tpu.tpu_config import *
# pylint: enable=wildcard-import,unused-import


@@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.tpu_context import *
from tensorflow_estimator.python.estimator.tpu.tpu_context import *
# pylint: enable=wildcard-import,unused-import


@@ -19,15 +19,15 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import,redefined-builtin
from tensorflow.python.tpu.tpu_estimator import *
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import *
# used by tests
from tensorflow.python.tpu.tpu_estimator import _clone_export_output_with_tensors
from tensorflow.python.tpu.tpu_estimator import _create_global_step
from tensorflow.python.tpu.tpu_estimator import _export_output_to_tensors
from tensorflow.python.tpu.tpu_estimator import _get_scaffold
from tensorflow.python.tpu.tpu_estimator import _Inputs
from tensorflow.python.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR
from tensorflow.python.tpu.tpu_estimator import _TPU_ENQUEUE_OPS
from tensorflow.python.tpu.tpu_estimator import _TPU_ESTIMATOR
from tensorflow.python.tpu.tpu_estimator import _TPU_TRAIN_OP
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _clone_export_output_with_tensors
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _create_global_step
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _export_output_to_tensors
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _get_scaffold
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _Inputs
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ENQUEUE_OPS
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ESTIMATOR
from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_TRAIN_OP
# pylint: enable=wildcard-import,unused-import,redefined-builtin


@@ -19,5 +19,5 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.tpu.util import *
from tensorflow_estimator.python.estimator.tpu.util import *
# pylint: enable=wildcard-import,unused-import


@@ -275,30 +275,6 @@ tf_py_test(
],
)
tf_py_test(
name = "tpu_config_test",
size = "small",
srcs = ["tpu_config_test.py"],
additional_deps = [
":tpu_estimator",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
name = "tpu_estimator_signals_test",
size = "small",
srcs = ["tpu_estimator_signals_test.py"],
additional_deps = [
":tpu_estimator",
"//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
],
# TODO(jhseu): Remove. Fails in OSS on Python 3.
tags = ["no_oss"],
)
tf_py_test(
name = "topology_test",
size = "medium",


@@ -1,382 +1,23 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Tooling for support TPU embedding in TPUEstimator."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.feature_column import feature_column as core_fc
from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.tpu import feature_column as tpu_fc
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu.tpu_embedding import AdagradParameters
from tensorflow.python.tpu.tpu_embedding import AdamParameters
from tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters
from tensorflow.python.training import training
# pylint: disable=protected-access
_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
tpu_fc._TPUSharedEmbeddingColumn)
_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
core_fc_lib.EmbeddingColumn,
core_fc._SharedEmbeddingColumn)
_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
_SUPPORTED_OPTIMIZERS = (AdagradParameters, AdamParameters,
StochasticGradientDescentParameters)
# pylint: enable=protected-access
_TABLE_NAME_PREFIX = 'tbl_'
_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)
def _get_table_name_from_embedding_var_name(embedding_var_name):
return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)
def _get_embedding_var_name_from_table_name(table_name):
return table_name[_LEN_TABLE_NAME_PREFIX:]
def _get_embedding_variable_name(scope_name, var_name):
return '{}/{}'.format(scope_name, var_name)
def _get_slot_variable_names(scope_name, var_name, optimization_parameters):
"""Return embedding variable names which are consistent with CPU runs."""
if isinstance(optimization_parameters, tpu_embedding.AdagradParameters):
return tpu_embedding.AdagradSlotVariableName(
'{}/{}/Adagrad'.format(scope_name, var_name)
)
elif isinstance(optimization_parameters, tpu_embedding.AdamParameters):
return tpu_embedding.AdamSlotVariableNames(
'{}/{}/Adam/m'.format(scope_name, var_name),
'{}/{}/Adam/v'.format(scope_name, var_name)
)
elif isinstance(optimization_parameters,
tpu_embedding.StochasticGradientDescentParameters):
return None
else:
raise ValueError('Support to infer full variable name '
'for optimization_parameter {} has not been added.'
.format(optimization_parameters))
def get_full_variable_names(
graph, table_to_config_dict, optimization_parameters=None):
"""Return embedding variable names and slot variables which are consistent with CPU runs."""
collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
if not collection:
raise RuntimeError(
'Embedding feature column did not capture any thing. Make sure the '
'feature columns passed to TPUEstimator constructor is properly '
'used in model_fn.')
embedding_variable_name_by_table = {}
slot_variable_names_by_table = {}
for table_name in table_to_config_dict:
embedding_var_name = _get_embedding_var_name_from_table_name(table_name)
(scope_name, var_name) = collection[0][embedding_var_name]
embedding_variable_name_by_table[table_name] = (
_get_embedding_variable_name(scope_name, var_name))
if optimization_parameters:
slot_variable_names_by_table[table_name] = _get_slot_variable_names(
scope_name, var_name, optimization_parameters)
graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
return embedding_variable_name_by_table, slot_variable_names_by_table
def get_configs_from_feature_columns(feature_columns):
"""Create configs for TPUEmbedding etc from a list of feature columns.
Args:
feature_columns: a list of supported feature columns.
Returns:
A tuple of dicts, the first maps tables to their config, the second maps
features to their config, and the third maps features to weight key names.
"""
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
for column in feature_columns:
if not isinstance(column, allowed):
raise TypeError(
'Unsupported feature column {}. Supported types are {}.'.format(
type(column), allowed))
table_to_config = {}
feature_to_config = {}
feature_to_weight_key_name = {}
for column in feature_columns:
feature_name = column.get_feature_key_name()
table_name = _get_table_name_from_embedding_var_name(
column.get_embedding_var_name())
if feature_name in feature_to_config:
raise ValueError(
'Feature column {} is used with multiple embeddings and this is '
'not supported.'.format(feature_name))
feature_to_config[feature_name] = tpu_embedding.FeatureConfig(
table_id=table_name)
feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
vocabulary_size, dimension = column.get_embedding_table_size()
table_to_config[table_name] = tpu_embedding.TableConfig(
vocabulary_size=vocabulary_size,
dimension=dimension,
initializer=column.get_initializer(),
combiner=column.get_combiner())
return table_to_config, feature_to_config, feature_to_weight_key_name
class EmbeddingConfigSpec(
collections.namedtuple('EmbeddingConfigSpec', [
'feature_columns', 'optimization_parameters', 'clipping_limit',
'pipeline_execution_with_tensor_core',
'experimental_gradient_multiplier_fn'
])):
"""Class to keep track of embedding config specification."""
def __new__(cls,
feature_columns,
optimization_parameters,
clipping_limit=None,
pipeline_execution_with_tensor_core=False,
experimental_gradient_multiplier_fn=None):
"""Creates an EmbeddingConfigSpec instance.
Args:
feature_columns: All `FeatureColumn`s used by model.
optimization_parameters: An instance of `AdagradParameters`,
`AdamParameters` or `StochasticGradientDescentParameters`. This
optimizer will be applied to all embedding variables specified by
`feature_columns`.
clipping_limit: (Optional) Clipping limit (absolute value).
pipeline_execution_with_tensor_core: setting this to `True` makes training
faster, but trained model will be different if step N and step N+1
involve the same set of embedding IDs. Please see
`tpu_embedding_configuration.proto` for details.
experimental_gradient_multiplier_fn: (Optional) A Fn taking global step as
input returning the current multiplier for all embedding gradients.
Returns:
An EmbeddingConfigSpec instance.
Raises:
ValueError: If the feature_columns are not specified.
TypeError: If the feature columns are not of the correct type (one of
_SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES or
_EMBEDDING_COLUMN_CLASSES).
ValueError: If `optimization_parameters` is not one of the required types.
"""
if not feature_columns:
raise ValueError('`feature_columns` cannot be `None` or empty.')
# It is unknown at this moment, whether the TPUEstimator is running in CPU
# or TPU mode. So allow non-TPU embedding columns also.
supported_classes = tuple(
list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) +
list(_EMBEDDING_COLUMN_CLASSES))
for column in feature_columns:
if not isinstance(column, supported_classes):
raise TypeError(
'All feature columns must be supported types in {}. Got {}'.format(
supported_classes, type(column)))
if not isinstance(optimization_parameters, _SUPPORTED_OPTIMIZERS):
raise ValueError('optimization_parameters must be an instance of type '
'{}. Got {}.'.format(_SUPPORTED_OPTIMIZERS,
type(optimization_parameters)))
return super(EmbeddingConfigSpec, cls).__new__(
cls,
feature_columns=feature_columns,
optimization_parameters=optimization_parameters,
clipping_limit=clipping_limit,
pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core,
experimental_gradient_multiplier_fn=experimental_gradient_multiplier_fn)
class EmbeddingConfig(object):
"""This is the internal immutable object for embedding config.
`_EmbeddingConfig` is responsible to _translate_ user provided
`EmbeddingConfigSpec` to internal data structures, mostly constructor
arguments of `TPUEmbedding`.
"""
def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
num_hosts, num_cores, run_config):
if not embedding_config_spec:
raise ValueError('embedding_config_spec cannot be None.')
self._embedding_config_spec = embedding_config_spec
self._train_batch_size = train_batch_size
self._eval_batch_size = eval_batch_size
self._num_hosts = num_hosts
self._num_cores = num_cores
self._run_config = run_config
(self._table_to_config_dict, self._feature_to_config_dict,
self.feature_to_weight_key_name_dict) = (
get_configs_from_feature_columns(
embedding_config_spec.feature_columns))
self._mode_to_tpu_embedding_dict = {}
self.dummy_table_variables = None
self._grad_multiplier_fn = (
embedding_config_spec.experimental_gradient_multiplier_fn)
def get_grad_multiplier(self):
if self._grad_multiplier_fn:
return ops.convert_to_tensor(
self._grad_multiplier_fn(training.get_global_step()),
dtype=dtypes.float32)
def has_embedding_tables(self):
return bool(self._table_to_config_dict)
def _create_tpu_embedding(self, mode):
"""Create tpu_embedding.TPUEmbedding based on mode."""
if mode == model_fn_lib.ModeKeys.TRAIN:
batch_size = self._train_batch_size
else:
batch_size = self._eval_batch_size
if mode == model_fn_lib.ModeKeys.TRAIN:
tpu_embedding_mode = tpu_embedding.TRAINING
optimization_parameters = (
self._embedding_config_spec.optimization_parameters)
elif (mode == model_fn_lib.ModeKeys.EVAL or
mode == model_fn_lib.ModeKeys.PREDICT):
tpu_embedding_mode = tpu_embedding.INFERENCE
optimization_parameters = None
else:
raise ValueError('Mode {} is not supported.'.format(mode))
if self._run_config.cluster:
master = self._run_config.cluster.master()
cluster_spec = self._run_config.cluster.cluster_spec()
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
else:
master = (
self._run_config.evaluation_master
if mode == model_fn_lib.ModeKeys.EVAL else self._run_config.master)
cluster_def = None
tpu_embedding_ = tpu_embedding.TPUEmbedding(
self._table_to_config_dict,
self._feature_to_config_dict,
batch_size,
tpu_embedding_mode,
master,
optimization_parameters,
cluster_def,
pipeline_execution_with_tensor_core=self._embedding_config_spec
.pipeline_execution_with_tensor_core)
return tpu_embedding_
def get_tpu_embedding(self, mode):
if mode not in self._mode_to_tpu_embedding_dict:
self._mode_to_tpu_embedding_dict[mode] = (
self._create_tpu_embedding(mode))
return self._mode_to_tpu_embedding_dict[mode]
def split_inputs(ctx, features, labels):
"""Splits the dense and sparse tensors inside the features and labels."""
enqueue_datas = collections.OrderedDict()
if ctx.embedding_config:
tpu_embedding_ = ctx.embedding_config.tpu_embedding
feature_to_weight_key_name_dict = (
ctx.embedding_config.feature_to_weight_key_name_dict)
for feature_key in tpu_embedding_.feature_to_config_dict:
sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
weight_key_name = feature_to_weight_key_name_dict[feature_key]
if isinstance(sparse_feature, sparse_tensor.SparseTensor):
weights = _get_weights_from_features(weight_key_name, features)
enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
sparse_feature, weights)
else:
if weight_key_name is not None:
raise ValueError(
'Found weights {} for weighted_categorical_column, which is not'
'compatible with sparse feature {} enqueued as dense tensor.'
.format(weight_key_name, feature_key))
enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
enqueue_datas[feature_key] = enqueue_data
return features, labels, enqueue_datas
def _get_sparse_feature_from_feature(feature_key, features):
"""Pop and return sparse feature."""
sparse_feature = features.pop(feature_key)
if not sparse_feature.dtype.is_integer:
raise ValueError('SparseTensor with string as values are not supported. '
'If you are using vocabulary_file_categorical_column or '
'vocabulary_list_categorical_column, please call '
'your_column.categorical_column._transform_feature({{'
'your_column.key: features[your_column.key]}}) in'
'your input_fn() to convert string to int. '
'feature_key = {}.'.format(feature_key))
return sparse_feature
def _get_weights_from_features(weight_key_name, features):
"""Pop and return feature for weights, possibly None."""
weights = None
if weight_key_name is not None:
if weight_key_name in features:
weights = features.pop(weight_key_name)
else:
raise ValueError(
'Cannot find weights {} for weighted_categorical_column.'
' Please check if the weights are present in feature dict. Also'
' note weight-sharing among weighted_categorical_column is not '
'supported on TPU.'.format(weight_key_name))
if not isinstance(weights, sparse_tensor.SparseTensor):
raise ValueError(
'weighted_categorical_column with weight key name {} has dense '
'weights. Dense weights are not supported on TPU. Please use '
'sparse weights instead.'.format(weight_key_name))
if weights.dtype is not dtypes.float32:
weights = math_ops.to_float(weights)
return weights
def get_tpu_embedding_columns(feature_columns):
"""Get feature columns meant to use TPU embedding.
Args:
feature_columns: a list of feature columns.
Returns:
A list of feature columns which can be placed on TPU embedding.
"""
tpu_embedding_columns = []
for column in feature_columns:
if isinstance(column, _TPU_EMBEDDING_COLUMN_CLASSES):
tpu_embedding_columns.append(column)
return tpu_embedding_columns
# pylint: disable=wildcard-import,unused-import
from tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import *
# pylint: enable=wildcard-import,unused-import
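
The removed module above defines `EmbeddingConfigSpec`, which bundles the embedding feature columns with a single optimizer applied to all embedding variables (see its docstring). A hedged usage sketch based only on that constructor: the `tpu_fc.embedding_column` helper and the surrounding TF 1.x feature-column calls are assumptions for illustration, not code from this diff.

```
# Hedged sketch: building an EmbeddingConfigSpec as documented above. Only the
# EmbeddingConfigSpec and AdagradParameters usage comes from this file; the
# column construction is illustrative.
import tensorflow as tf
from tensorflow.python.tpu import feature_column as tpu_fc
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu import _tpu_estimator_embedding

user_id = tf.feature_column.categorical_column_with_identity(
    'user_id', num_buckets=10000)
user_embedding = tpu_fc.embedding_column(user_id, dimension=16)  # assumed helper

spec = _tpu_estimator_embedding.EmbeddingConfigSpec(
    feature_columns=[user_embedding],
    optimization_parameters=tpu_embedding.AdagradParameters(learning_rate=0.1))
```

The internal `EmbeddingConfig` class above then translates such a spec into the `TPUEmbedding` table and feature configs via `get_configs_from_feature_columns`.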


@@ -1,135 +1,23 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""ErrorRendezvous handler for collecting errors from multiple threads."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import sys
import threading
import time
import six
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
_UNINTERESTING_ERRORS = (errors.CancelledError,)
class ErrorRendezvous(object):
"""Resolve errors from multiple threads during TPU execution.
TPU errors can occur on the infeed or outfeed threads as well as the main
training thread.
Depending on which thread "wins" and receives the session error first, we may
end up showing users a confusing and non-actionable error message (session
cancelled) instead of a root cause (e.g. a bad filename).
The rendezvous object provides a location to capture these errors until all
threads terminate. At that point we can choose the most informative error
to report.
"""
def __init__(self, num_sources):
# string -> (message, traceback)
self._errors = {}
self._num_sources = num_sources
self._session_cancel_timer = None
def record_error(self, source, exc_info, session=None):
"""Report an exception from the given source.
If a session is passed, a timer will be registered to close it after a few
seconds. This is necessary to ensure the main training loop does not hang
if an infeed/outfeed error occurs. We sleep a few seconds to allow a more
interesting error from another thread to propagate.
Args:
source: string, source of the error
exc_info: Output from `sys.exc_info` (type, value, traceback)
session: Session to close after delay.
"""
_, value, _ = exc_info
self._errors[source] = exc_info
logging.error('Error recorded from %s: %s', source, value)
if session is not None and self._session_cancel_timer is None:
def _cancel_session():
time.sleep(5)
logging.error('Closing session due to error %s' % value)
try:
session.close()
except: # pylint: disable=bare-except
logging.error(
'\n\n\nFailed to close session after error.'
'Other threads may hang.\n\n\n')
self._session_cancel_timer = threading.Thread(target=_cancel_session,)
self._session_cancel_timer.daemon = True
self._session_cancel_timer.start()
def record_done(self, source):
"""Mark execution source `source` as done.
If an error was originally reported from `source` it is left intact.
Args:
source: `str`, source being recorded
"""
logging.info('%s marked as finished', source)
if source not in self._errors:
self._errors[source] = None
@contextlib.contextmanager
def catch_errors(self, source, session=None):
"""Context manager to report any errors within a block."""
try:
yield
except Exception: # pylint: disable=broad-except
self.record_error(source, sys.exc_info(), session)
def raise_errors(self, timeout_sec=0):
"""Wait for up to `timeout` seconds for all error sources to finish.
Preferentially raise "interesting" errors (errors not in the
_UNINTERESTING_ERRORS) set.
Args:
timeout_sec: Seconds to wait for other error sources.
"""
for _ in range(timeout_sec):
if len(self._errors) == self._num_sources:
break
time.sleep(1)
kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None]
# First check for any interesting errors, then fall back on the session
# cancelled errors etc.
for k, (typ, value, traceback) in kept_errors:
if isinstance(value, _UNINTERESTING_ERRORS):
continue
else:
logging.warn('Reraising captured error')
six.reraise(typ, value, traceback)
for k, (typ, value, traceback) in kept_errors:
logging.warn('Reraising captured error')
six.reraise(typ, value, traceback)
# pylint: disable=wildcard-import,unused-import
from tensorflow_estimator.python.estimator.tpu.error_handling import *
# pylint: enable=wildcard-import,unused-import
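
The `ErrorRendezvous` class defined above coordinates error reporting across the infeed/outfeed threads and the main training thread. A minimal sketch of the pattern its docstrings describe; the worker bodies are placeholders, not code from this change:

```
# Hedged sketch of the ErrorRendezvous pattern: each worker reports through
# catch_errors/record_done, and the coordinator calls raise_errors to surface
# the most informative captured error.
import threading
from tensorflow.python.tpu import error_handling

rendezvous = error_handling.ErrorRendezvous(num_sources=2)

def infeed_worker():
  with rendezvous.catch_errors(source='infeed'):
    pass  # enqueue input batches here (placeholder)
  rendezvous.record_done('infeed')

def train_worker():
  with rendezvous.catch_errors(source='training'):
    pass  # run the training loop here (placeholder)
  rendezvous.record_done('training')

threads = [threading.Thread(target=infeed_worker),
           threading.Thread(target=train_worker)]
for t in threads:
  t.start()
for t in threads:
  t.join()

# Wait briefly for stragglers, then re-raise the most interesting error, if any.
rendezvous.raise_errors(timeout_sec=5)
```

Passing a `session` to `catch_errors`/`record_error` additionally arms the delayed session close described in `record_error`, so the main loop does not hang on infeed failures.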


@@ -1,295 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""A RunConfig subclass with TPU support."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import os
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import util as util_lib
# pylint: disable=protected-access
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
# pylint: enable=protected-access
class InputPipelineConfig(object):
r"""Please see the definition of these values in TPUConfig."""
PER_SHARD_V1 = 1
PER_HOST_V1 = 2
PER_HOST_V2 = 3
BROADCAST = 4
SLICED = 5
class TPUConfig(
collections.namedtuple('TPUConfig', [
'iterations_per_loop',
'num_shards',
'num_cores_per_replica',
'per_host_input_for_training',
'tpu_job_name',
'initial_infeed_sleep_secs',
'input_partition_dims',
'eval_training_input_configuration',
])):
r"""TPU related configuration required by `TPUEstimator`.
Args:
iterations_per_loop: This is the number of train steps running in TPU
system before returning to CPU host for each `Session.run`. This means
global step is increased `iterations_per_loop` times in one `Session.run`.
It is recommended to be set as number of global steps for next checkpoint.
Note that in evaluation don't use this value, instead we run total eval
`steps` on TPU for a single `Session.run`.
num_shards: (Deprecated, ignored by TPUEstimator).
The number of model replicas in the system. For non-model-parallelism
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
num_cores_per_replica * num_shards.
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
An integer which describes the number of TPU cores per model replica. This
is required by model-parallelism which enables partitioning
the model to multiple cores. Currently num_cores_per_replica must be
1, 2, 4, or 8.
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
`input_fn` is invoked once on each host. With the per-core input pipeline
configuration, it is invoked once for each core.
With a global batch size `train_batch_size` in `TPUEstimator` constructor,
the batch size for each shard is `train_batch_size` // #hosts in the
`True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
`train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
invoked once on host 0 and the tensors are broadcasted to all other
replicas. The batch size equals to `train_batch_size`. With the per-core
input pipeline configuration, the shard batch size is also
`train_batch_size` // #cores.
Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
within TPUEstimator, however when using ClusterSpec propagation in more
esoteric cluster configurations, you may need to specify the job name as a
string.
initial_infeed_sleep_secs: The number of seconds the infeed thread should
wait before enqueueing the first batch. This helps avoid timeouts for
models that require a long compilation time.
input_partition_dims: A nested list to describe the partition dims
for all the tensors from input_fn(). The structure of
input_partition_dims must match the structure of `features` and
`labels` from input_fn(). The total number of partitions must match
`num_cores_per_replica`. For example, if input_fn() returns two tensors:
images with shape [N, H, W, C] and labels [N].
input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4
pieces and feed into 4 TPU cores. labels tensor are directly broadcasted
to all the TPU cores since the partition dims is `None`.
Current limitations: This feature is only supported with the PER_HOST_V2
input mode.
eval_training_input_configuration: If `SLICED`, `input_fn` is only
invoked once on host 0 and the tensors are broadcasted to all other
replicas. Unlike per_host_input_for_training=BROADCAST, each replica will
only get a slice of the data instead of a whole copy. If `PER_HOST_V1`,
the behaviour is determined by per_host_input_for_training.
Raises:
ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
"""
def __new__(
cls,
iterations_per_loop=2,
num_shards=None,
num_cores_per_replica=None,
per_host_input_for_training=True,
tpu_job_name=None,
initial_infeed_sleep_secs=None,
input_partition_dims=None,
eval_training_input_configuration=InputPipelineConfig.PER_HOST_V1):
# Check iterations_per_loop.
util_lib.check_positive_integer(iterations_per_loop,
'TPUConfig iterations_per_loop')
# Check num_shards.
if num_shards is not None:
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
if input_partition_dims is not None:
if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
raise ValueError(
'input_partition_dims must be a list/tuple with one or two'
' elements.')
if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
raise ValueError(
'input_partition_dims is only supported in PER_HOST_V2 mode.')
if num_cores_per_replica is None:
raise ValueError(
'input_partition_dims requires setting num_cores_per_replica.')
# Check num_cores_per_replica
if num_cores_per_replica is not None:
if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
if eval_training_input_configuration not in [
InputPipelineConfig.PER_HOST_V1, InputPipelineConfig.SLICED
]:
raise ValueError(
'eval_training_input_configuration must be PER_HOST_V1 or SLICED;'
' got {}'.format(str(eval_training_input_configuration)))
# per_host_input_for_training may be True, False, or integer in [1..3].
# Map legacy values (True, False) to numeric values.
if per_host_input_for_training is False:
per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
elif per_host_input_for_training is True:
per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
# Check initial_infeed_sleep_secs.
if initial_infeed_sleep_secs:
util_lib.check_positive_integer(initial_infeed_sleep_secs,
'TPUConfig initial_infeed_sleep_secs')
tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
return super(TPUConfig, cls).__new__(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
num_cores_per_replica=num_cores_per_replica,
per_host_input_for_training=per_host_input_for_training,
tpu_job_name=tpu_job_name,
initial_infeed_sleep_secs=initial_infeed_sleep_secs,
input_partition_dims=input_partition_dims,
eval_training_input_configuration=eval_training_input_configuration)
class RunConfig(run_config_lib.RunConfig):
"""RunConfig with TPU support."""
def __init__(self,
tpu_config=None,
evaluation_master=None,
master=None,
cluster=None,
**kwargs):
"""Constructs a RunConfig.
Args:
tpu_config: the TPUConfig that specifies TPU-specific configuration.
evaluation_master: a string. The address of the master to use for eval.
Defaults to master if not set.
master: a string. The address of the master to use for training.
cluster: a ClusterResolver
**kwargs: keyword config parameters.
Raises:
ValueError: if cluster is not None and the provided session_config has a
cluster_def already.
"""
super(RunConfig, self).__init__(**kwargs)
self._tpu_config = tpu_config or TPUConfig()
self._cluster = cluster
# If user sets master and/or evaluation_master explicitly, including empty
# string '', take it. Otherwise, take the values set by parent class.
if master is not None:
if cluster is not None:
raise ValueError('Both master and cluster are set.')
self._master = master
else:
if cluster:
self._master = cluster.master()
if evaluation_master is not None:
self._evaluation_master = evaluation_master
elif (not self._evaluation_master and
self.task_type != run_config_lib.TaskType.EVALUATOR):
# If the task type is EVALUATOR, it means some cluster manager sets the
# TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
#
# Otherwise, it means user executes the code without external cluster
# manager. For that, we optimize the user experience by setting
# evaluation_master to master, unless user overwrites it.
self._evaluation_master = self._master
# Set the ClusterSpec to use
if cluster:
self._cluster_spec = cluster.cluster_spec()
# Merge the cluster_def into the ConfigProto.
if self._session_config is None: # pylint: disable=access-member-before-definition
self._session_config = config_pb2.ConfigProto(
allow_soft_placement=True, isolate_session_state=True)
if self._session_config.HasField('cluster_def'):
raise ValueError(
'You cannot provide a ClusterResolver and '
'session_config.cluster_def.')
if self._cluster_spec:
self._session_config.cluster_def.CopyFrom(
self._cluster_spec.as_cluster_def())
def _maybe_overwrite_session_config_for_distributed_training(self):
# Overrides the parent class session_config overwrite for between-graph. TPU
# runs with in-graph, which should not have device filter. Doing nothing
# ("pass") basically disables it.
pass
@property
def evaluation_master(self):
return self._evaluation_master
@property
def master(self):
return self._master
@property
def tpu_config(self):
return self._tpu_config
@property
def cluster(self):
return self._cluster
def replace(self, **kwargs):
if 'tpu_config' not in kwargs:
return super(RunConfig, self).replace(**kwargs)
tpu_config = kwargs.pop('tpu_config')
new_instance = super(RunConfig, self).replace(**kwargs)
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
return new_instance
def _get_tpu_job_name_from_tf_config():
"""Extracts the TPU job name from TF_CONFIG env variable."""
# TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster
# spec propagation.
tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
if tpu_job_name:
logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
return tpu_job_name
# pylint: disable=wildcard-import,unused-import
from tensorflow_estimator.python.estimator.tpu.tpu_config import *
# pylint: enable=wildcard-import,unused-import
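
The removed module above defines `TPUConfig` and the TPU-aware `RunConfig`. A hedged sketch of how they are wired together and handed to `TPUEstimator`, based on the docstrings above: the master address, `model_dir`/`save_checkpoints_steps` (inherited from the base Estimator `RunConfig`), and `my_model_fn` are illustrative placeholders, and the `TPUEstimator` constructor arguments beyond `use_tpu` and `train_batch_size` are assumptions.

```
# Hedged sketch: TPUConfig controls how often the TPU loop returns to the host
# (iterations_per_loop) and how input_fn is deployed; RunConfig carries it
# along with the usual Estimator settings.
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_estimator

run_config = tpu_config.RunConfig(
    master='grpc://10.0.0.2:8470',   # illustrative TPU master address
    model_dir='/tmp/tpu_model',
    save_checkpoints_steps=1000,
    tpu_config=tpu_config.TPUConfig(
        iterations_per_loop=1000,    # train steps per Session.run on the TPU
        per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2))

estimator = tpu_estimator.TPUEstimator(
    model_fn=my_model_fn,            # placeholder model_fn
    config=run_config,
    use_tpu=True,
    train_batch_size=1024)
```

As the `TPUConfig` docstring above notes, `iterations_per_loop` trades off host round-trips against checkpointing granularity and is typically set to the number of global steps between checkpoints.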


@@ -1,181 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU RunConfig tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_config as tpu_config_lib
def _set_tf_config_env_variable(tf_config):
return test.mock.patch.dict('os.environ', {
'TF_CONFIG': json.dumps(tf_config)
})
class TPURunConfigTest(test.TestCase):
def test_no_session_config_set_in_local_case(self):
run_config = tpu_config_lib.RunConfig()
self.assertIsNone(run_config.session_config)
def test_no_session_config_overwrite_in_local_case(self):
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
run_config = tpu_config_lib.RunConfig(session_config=session_config)
self.assertEqual(session_config, run_config.session_config)
def test_no_session_config_set_with_cluster_spec(self):
tf_config = {
'cluster': {
run_config_lib.TaskType.CHIEF: ['host3:3'],
run_config_lib.TaskType.WORKER: ['host3:4']
},
'task': {
'type': run_config_lib.TaskType.CHIEF,
'index': 0
}
}
with _set_tf_config_env_variable(tf_config):
run_config = tpu_config_lib.RunConfig()
self.assertIsNone(run_config.session_config)
def test_no_session_config_overwrite_with_cluster_spec(self):
tf_config = {
'cluster': {
run_config_lib.TaskType.CHIEF: ['host3:3'],
run_config_lib.TaskType.WORKER: ['host3:4']
},
'task': {
'type': run_config_lib.TaskType.CHIEF,
'index': 0
}
}
with _set_tf_config_env_variable(tf_config):
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
run_config = tpu_config_lib.RunConfig(session_config=session_config)
self.assertEqual(session_config, run_config.session_config)
def test_fail_with_invalid_num_shards(self):
with self.assertRaisesRegexp(ValueError, 'must be positive'):
tpu_config_lib.RunConfig(
tpu_config=tpu_config_lib.TPUConfig(num_shards=0))
def test_fail_with_iterations_per_loop(self):
with self.assertRaisesRegexp(ValueError, 'must be positive'):
tpu_config_lib.RunConfig(
tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
def test_fail_with_invalid_num_cores_per_replica(self):
with self.assertRaisesRegexp(
ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;'
' got 7'):
tpu_config_lib.TPUConfig(num_cores_per_replica=7)
class TPURunConfigMasterTest(test.TestCase):
def test_default_values(self):
run_config = tpu_config_lib.RunConfig()
self.assertEqual('', run_config.master)
self.assertEqual('', run_config.evaluation_master)
def test_user_provided_master_and_evaluation_master(self):
run_config = tpu_config_lib.RunConfig(
master='_master_123', evaluation_master='_eval_master_123')
self.assertEqual('_master_123', run_config.master)
self.assertEqual('_eval_master_123', run_config.evaluation_master)
def test_evaluation_master_defaults_to_master(self):
run_config = tpu_config_lib.RunConfig(master='_master_123')
self.assertEqual('_master_123', run_config.master)
self.assertEqual('_master_123', run_config.evaluation_master)
def test_tf_config(self):
tf_config = {
'session_master': '_master_123',
'eval_session_master': '_eval_master_123'
}
with _set_tf_config_env_variable(tf_config):
run_config = tpu_config_lib.RunConfig()
self.assertEqual('_master_123', run_config.master)
self.assertEqual('_eval_master_123', run_config.evaluation_master)
def test_evaluation_master_defaults_to_master_in_tf_config(self):
tf_config = {
'session_master': '_master_123',
}
with _set_tf_config_env_variable(tf_config):
run_config = tpu_config_lib.RunConfig()
self.assertEqual('_master_123', run_config.master)
self.assertEqual('_master_123', run_config.evaluation_master)
def test_respect_evaluation_master_in_tf_config(self):
tf_config = {
'cluster': {
run_config_lib.TaskType.CHIEF: ['host0:0'],
},
'task': {
'type': run_config_lib.TaskType.EVALUATOR,
'index': 0
},
}
with _set_tf_config_env_variable(tf_config):
run_config = tpu_config_lib.RunConfig(master='_something')
self.assertEqual('', run_config.evaluation_master)
def test_user_overwrites_tf_config(self):
tf_config = {
'session_master': '_master_123',
'eval_session_master': '_eval_master_123'
}
with _set_tf_config_env_variable(tf_config):
run_config = tpu_config_lib.RunConfig(
master='_new_master_123', evaluation_master='_new_eval_master_123')
self.assertEqual('_new_master_123', run_config.master)
self.assertEqual('_new_eval_master_123', run_config.evaluation_master)
def test_user_overwrites_master_in_tf_config(self):
tf_config = {
'session_master': '_master_123',
'eval_session_master': '_eval_master_123'
}
with _set_tf_config_env_variable(tf_config):
run_config = tpu_config_lib.RunConfig(master='_new_master_123')
self.assertEqual('_new_master_123', run_config.master)
self.assertEqual('_eval_master_123', run_config.evaluation_master)
class TPUJobNameTest(test.TestCase):
def test_default_name(self):
config = tpu_config_lib.RunConfig()
self.assertIsNone(config.tpu_config.tpu_job_name)
def test_with_tf_config(self):
tf_config = {'service': {'tpu_worker_job_name': '_my_new_name',}}
with _set_tf_config_env_variable(tf_config):
config = tpu_config_lib.RunConfig()
self.assertEqual('_my_new_name', config.tpu_config.tpu_job_name)
if __name__ == '__main__':
test.main()


@@ -1,749 +1,23 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from contextlib import contextmanager
import copy
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import _tpu_estimator_embedding
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
_NUM_CORES_TO_COMPUTATION_SHAPE = {
1: [1, 1, 1],
2: [1, 1, 2],
4: [1, 2, 2],
8: [2, 2, 2],
16: [4, 2, 2],
}
class TPUContext(object):
"""A context that holds the current configuration of the TPU computation."""
def __init__(self,
internal_ctx,
input_device=None,
invocation_index=None,
call_from_input_fn=True):
self._internal_ctx = internal_ctx
self._input_device = input_device
self._invocation_index = invocation_index
self._call_from_input_fn = call_from_input_fn
def current_input_fn_deployment(self):
"""The configuration of the current input_fn invocation.
The configuration depends on `TPUConfig.per_host_input_for_training`. See
`TPUConfig` for details.
Only set in params dict of input_fn
Returns:
A tuple of
1. Device spec string: String, is the current CPU host where the
input_fn is invoked.
2. Current invocation index: Int, 0-based index of the input_fn
invocation. See next item for details.
3. Total invocation count: Int, the total number of times to invoke the
input_fn on all CPU hosts. Each invocation will be passed with a new
`TPUContext` instance with current invocation index set properly.
4. Total number of replicas consumed by current_invocation: Int, the
number of replicas fed by the data returned by current input_fn. For
example, for per_core input pipeline deployment
and non-model-parallelism, total invocation count is equal to
the number of cores in the system and num replicas consumed by
current invocation is 1. For per-host v2 input pipeline deployment,
total invocation count is equal to the number of hosts in the system
and num replicas consumed by current invocation is equal to number of
cores per host.
Raises:
RuntimeError: If this method is called from model_fn; it may only be called from input_fn.
"""
if not self._call_from_input_fn:
raise RuntimeError('This TPUContext instance must not be called from'
' model_fn.')
if self._internal_ctx.is_input_sharded_per_core():
total_invocation_count = (self._internal_ctx.num_hosts
* self._internal_ctx.num_of_replicas_per_host)
replicas_consumed = 1
elif self._internal_ctx.is_input_broadcast_with_iterators():
total_invocation_count = 1
replicas_consumed = self._internal_ctx.num_replicas
else:
total_invocation_count = self._internal_ctx.num_hosts
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
return (self._input_device, self._invocation_index,
total_invocation_count, replicas_consumed)
@property
def num_replicas(self):
"""The total number of replicas.
For non-model-parallelism, num_replicas should be the total num of TPU
cores in the system.
Returns:
The number of replicas.
"""
return self._internal_ctx.num_replicas
@property
def num_hosts(self):
"""The number of hosts for the TPU system."""
return self._internal_ctx.num_hosts
@property
def current_host(self):
"""The current host index for the TPU system."""
return self._invocation_index
@property
def num_of_replicas_per_host(self):
"""The number of replicas for each host."""
if self._internal_ctx.model_parallelism_enabled:
raise ValueError(
'num_of_replicas_per_host is not supported for model_parallelism')
return self._internal_ctx.num_of_replicas_per_host
@property
def device_assignment(self):
"""Returns device_assignment object."""
if self._call_from_input_fn:
raise RuntimeError('This TPUContext instance must not be called from'
' input_fn.')
return self._internal_ctx.device_assignment
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
This should be used for full replicate for non-model-parallelism.
Args:
replica_id: Int, the replica index.
Returns:
A tuple of device spec for CPU device and int device ordinal.
"""
# Note that: For the non-model parallelism, the mapping could be
# a random permutation. The order should not matter in most cases
# as far as model is replicated to all cores in the system.
return self._internal_ctx.device_for_replica(replica_id)
@property
def tpu_host_placement_function(self):
"""Returns the TPU host place function.
The place function takes host_id as the input and returns the TF device
for the corresponding host.
"""
def _placement_function(host_id):
"""Return the host device given host_id."""
return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
return _placement_function
class _InternalTPUContext(object):
"""A context holds immutable states of TPU computation.
This immutable object holds TPUEstimator config, train/eval batch size, and
`TPUEstimator.use_tpu`, which is expected to be passed around. It also
provides utility functions, based on the current state, to determine other
information commonly required by TPU computation, such as TPU device names,
TPU hosts, shard batch size, etc.
if eval_on_tpu is False, then execution of eval on TPU is disabled.
if eval_on_tpu is True, but use_tpu is False, a warning is issued,
and TPU execution is disabled for all modes.
N.B. As `mode` is not immutable state in Estimator, but essential to
distinguish between TPU training and evaluation, a common usage for
_InternalTPUContext with `mode` is as follows:
```
with _ctx.with_mode(mode) as ctx:
if ctx.is_running_on_cpu():
...
```
"""
def __init__(self,
config,
train_batch_size,
eval_batch_size,
predict_batch_size,
use_tpu,
eval_on_tpu=True,
embedding_config_spec=None):
self._config = config
self._train_batch_size = train_batch_size
self._eval_batch_size = eval_batch_size
self._predict_batch_size = predict_batch_size
self._use_tpu = use_tpu
logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
if not use_tpu and eval_on_tpu:
logging.warning('eval_on_tpu ignored because use_tpu is False.')
self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
use_tpu and config.tpu_config.num_cores_per_replica)
self._mode = None
num_cores_per_replica = config.tpu_config.num_cores_per_replica
if self._model_parallelism_enabled:
self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
num_cores_per_replica]
else:
self._computation_shape = None
self._lazy_tpu_system_metadata_dict = {} # key by master address
self._lazy_device_assignment_dict = {} # key by master address
self._lazy_validation_dict = {} # key by ModeKeys
self._embedding_config_spec = embedding_config_spec
self._lazy_embedding_config_dict = {} # key by master address
def _assert_mode(self):
if self._mode is None:
raise RuntimeError(
'`mode` needs to be set via contextmanager `with_mode`.')
return self._mode
@contextmanager
def with_mode(self, mode):
# NOTE(xiejw): Shallow copy is enough. It will share the lazy dictionaries,
# such as _lazy_tpu_system_metadata_dict between new copy and the original
# one. Note that all lazy states stored in properties _lazy_foo are sort of
# immutable as they should be same for the process lifetime.
new_ctx = copy.copy(self)
new_ctx._mode = mode # pylint: disable=protected-access
yield new_ctx
@property
def mode(self):
return self._assert_mode()
def _get_master_address(self):
mode = self._assert_mode()
config = self._config
master = (
config.master
if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
return master
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
master = self._get_master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
cluster_def = None
if (self._config.session_config and
self._config.session_config.cluster_def.job):
cluster_def = self._config.session_config.cluster_def
# pylint: disable=protected-access
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
master,
cluster_def=cluster_def,
query_topology=self.model_parallelism_enabled))
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
return tpu_system_metadata
def _get_device_assignment(self):
"""Gets the (maybe cached) TPU device assignment."""
master = self._get_master_address()
device_assignment = self._lazy_device_assignment_dict.get(master)
if device_assignment is not None:
return device_assignment
tpu_system_metadata = self._get_tpu_system_metadata()
device_assignment = tpu_device_assignment.device_assignment(
tpu_system_metadata.topology,
computation_shape=self._computation_shape,
num_replicas=self.num_replicas)
logging.info('num_cores_per_replica: %s',
str(self._config.tpu_config.num_cores_per_replica))
logging.info('computation_shape: %s', str(self._computation_shape))
logging.info('num_replicas: %d', self.num_replicas)
logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
logging.info('device_assignment.core_assignment: %s',
str(device_assignment.core_assignment))
self._lazy_device_assignment_dict[master] = device_assignment
return device_assignment
@property
def embedding_config(self):
"""Returns the embedding config based on current mode."""
master = self._get_master_address()
if master in self._lazy_embedding_config_dict:
embedding_config = self._lazy_embedding_config_dict[master]
else:
embedding_config = None
if self._use_tpu and self._embedding_config_spec:
embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
self._embedding_config_spec, self._train_batch_size,
self._eval_batch_size, self.num_hosts, self.num_cores, self.config)
if not embedding_config.has_embedding_tables():
embedding_config = None
self._lazy_embedding_config_dict[master] = embedding_config
if embedding_config is not None:
mode = self._assert_mode()
# Dynamically attach tpu_embedding based on mode. With
# this, we could keep embedding_config immutable but call site always
# accesses the unified API '.tpu_embedding'.
embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
return embedding_config
@property
def model_parallelism_enabled(self):
return self._model_parallelism_enabled
@property
def input_partition_dims(self):
return self._config.tpu_config.input_partition_dims
@property
def device_assignment(self):
return (self._get_device_assignment()
if self._model_parallelism_enabled else None)
@property
def num_of_cores_per_host(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_of_cores_per_host
@property
def num_cores(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_cores
@property
def num_of_replicas_per_host(self):
"""Return the number of replicas per host."""
if self.model_parallelism_enabled:
return self.num_replicas // self.num_hosts
else:
return self.num_of_cores_per_host
@property
def num_replicas(self):
num_cores_in_system = self.num_cores
if self.model_parallelism_enabled:
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
if num_cores_per_replica > num_cores_in_system:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica, is larger than the total num of '
'TPU cores in the system. num_cores_per_replica: {}, num cores '
'in the system: {}'.format(num_cores_per_replica,
num_cores_in_system))
if num_cores_in_system % num_cores_per_replica != 0:
raise RuntimeError(
'The num of cores in the system ({}) is not divisible by the num '
'of cores ({}) required by the model parallelism, specified by '
'TPUConfig.num_cores_per_replica. This should never happen!'.format(
num_cores_in_system, num_cores_per_replica))
return num_cores_in_system // num_cores_per_replica
else:
return num_cores_in_system
@property
def num_hosts(self):
metadata = self._get_tpu_system_metadata()
return metadata.num_hosts
@property
def config(self):
return self._config
def is_input_sharded_per_core(self):
"""Return true if input_fn is invoked per-core (other than per-host)."""
mode = self._assert_mode()
return (mode == model_fn_lib.ModeKeys.TRAIN and
(self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_SHARD_V1))
def is_input_per_host_with_iterators(self):
"""Return true if input_fn should be run in the per-host v2 config."""
return (self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_HOST_V2)
def is_input_broadcast_with_iterators(self):
"""Return true if input_fn should be run in the full_replicae config."""
mode = self._assert_mode()
return ((self._config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.BROADCAST) or
(mode != model_fn_lib.ModeKeys.TRAIN and
self._config.tpu_config.eval_training_input_configuration is
tpu_config.InputPipelineConfig.SLICED))
def is_running_on_cpu(self, is_export_mode=False):
"""Determines whether the input_fn and model_fn should be invoked on CPU.
This API also validates user-provided configuration, such as batch size,
according to the lazily initialized TPU system metadata.
Args:
is_export_mode: Indicates whether the current mode is for exporting the
model, when mode == PREDICT. Only with this bool, we could
tell whether user is calling the Estimator.predict or
Estimator.export_savedmodel, which are running on TPU and CPU
respectively. Parent class Estimator does not distinguish these two.
Returns:
bool, whether current input_fn or model_fn should be running on CPU.
Raises:
ValueError: any configuration is invalid.
"""
is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
if not is_running_on_cpu:
self._validate_tpu_configuration()
return is_running_on_cpu
def _is_running_on_cpu(self, is_export_mode):
"""Determines whether the input_fn and model_fn should be invoked on CPU."""
mode = self._assert_mode()
if not self._use_tpu:
return True
if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
return True
if is_export_mode:
return True
return False
@property
def global_batch_size(self):
mode = self._assert_mode()
if mode == model_fn_lib.ModeKeys.TRAIN:
return self._train_batch_size
elif mode == model_fn_lib.ModeKeys.EVAL:
return self._eval_batch_size
elif mode == model_fn_lib.ModeKeys.PREDICT:
return self._predict_batch_size
else:
return None
@property
def batch_size_for_input_fn(self):
"""Returns the shard batch size for `input_fn`."""
global_batch_size = self.global_batch_size
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU
if self.is_input_sharded_per_core() or (
self.is_input_per_host_with_iterators()):
return global_batch_size // self.num_replicas
else:
return global_batch_size // self.num_hosts
@property
def batch_size_for_model_fn(self):
"""Returns the shard batch size for `model_fn`."""
global_batch_size = self.global_batch_size
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
return global_batch_size
# On TPU, the batch is always sharded per replica.
return global_batch_size // self.num_replicas
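# Illustrative sketch, not part of the original file (assumed values): how the
# two batch-size properties above split a global batch size across TPU
# replicas and hosts.
def _example_shard_batch_sizes():
  global_batch_size = 1024  # hypothetical TPUEstimator train_batch_size
  num_replicas = 8          # hypothetical number of TPU replicas
  num_hosts = 2             # hypothetical number of TPU hosts
  per_replica_model_fn = global_batch_size // num_replicas  # 128, model_fn shard
  per_core_input_fn = global_batch_size // num_replicas     # 128, PER_SHARD_V1 / PER_HOST_V2
  per_host_input_fn = global_batch_size // num_hosts        # 512, PER_HOST_V1
  return per_replica_model_fn, per_core_input_fn, per_host_input_fn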
@property
def master_job(self):
"""Returns the job name to use to place TPU computations on.
Returns:
A string containing the job name, or None if no job should be specified.
Raises:
ValueError: If the user needs to specify a tpu_job_name, because we are
unable to infer the job name automatically, or if the user-specified job
names are inappropriate.
"""
run_config = self._config
# If the user specifies the tpu_job_name, use that.
if run_config.tpu_config.tpu_job_name:
return run_config.tpu_config.tpu_job_name
# The tpu job is determined by the run_config. Right now, this method is
# required as tpu_config is not part of the RunConfig.
mode = self._assert_mode()
master = (
run_config.evaluation_master
if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
cluster_def = (run_config.session_config.cluster_def
if run_config.session_config else None)
return tpu_system_metadata_lib.master_job(master, cluster_def)
@property
def tpu_host_placement_function(self):
"""Returns the TPU host place function."""
master = self.master_job
def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name
"""Return the host device given replica_id or host_id."""
assert _sentinal is None
if replica_id is not None and host_id is not None:
raise RuntimeError(
'Only one of replica_id and host_id may be non-None.')
if master is None:
return '/replica:0/task:0/device:CPU:0'
else:
if replica_id is not None:
if self.model_parallelism_enabled:
return self.device_assignment.host_device(
replica=replica_id, job=master)
else:
host_id = replica_id // self.num_of_cores_per_host
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
return _placement_function
@property
def tpu_device_placement_function(self):
"""Returns a TPU device placement Fn."""
master = self.master_job
job_device = '' if master is None else ('/job:%s' % master)
def _placement_function(i):
if self.model_parallelism_enabled:
return self.device_assignment.tpu_device(replica=i, job=master)
else:
num_of_cores_per_host = self.num_of_cores_per_host
host_id = i // num_of_cores_per_host
ordinal_id = i % num_of_cores_per_host
return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)
return _placement_function
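# Illustrative sketch, not part of the original file (hypothetical job name
# and core counts): the device strings the two placement functions above would
# produce in the non-model-parallel branch.
def _example_placement_strings():
  master = 'tpu_worker'      # assumed job name returned by master_job
  num_of_cores_per_host = 8  # assumed cores per host
  i = 9                      # global core index
  host_id = i // num_of_cores_per_host
  ordinal_id = i % num_of_cores_per_host
  host_device = '/job:%s/task:%d/device:CPU:0' % (master, host_id)
  tpu_device = '/job:%s/task:%d/device:TPU:%d' % (master, host_id, ordinal_id)
  # ('/job:tpu_worker/task:1/device:CPU:0', '/job:tpu_worker/task:1/device:TPU:1')
  return host_device, tpu_device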
def tpu_ordinal_function(self, host_id):
"""Returns the TPU ordinal fn."""
def _tpu_ordinal_function(shard_index_in_host):
"""Return the TPU ordinal associated with a shard.
Required because the enqueue ops are placed on CPU.
Args:
shard_index_in_host: the shard index
Returns:
The ordinal of the TPU device the shard's infeed should be placed on.
"""
if self.model_parallelism_enabled:
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
replica = self.device_assignment.lookup_replicas(host_id,
0)[shard_index_in_host]
return self.device_assignment.tpu_ordinal(replica=replica)
else:
return shard_index_in_host % self.num_of_cores_per_host
return _tpu_ordinal_function
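# Illustrative sketch, not part of the original file (assumed core count):
# without model parallelism, the ordinal function above simply wraps the shard
# index by the number of cores per host.
def _example_tpu_ordinals(num_of_cores_per_host=8):
  # Shard indices 0..9 map to ordinals 0..7, then 0, 1 on an 8-core host.
  return [shard % num_of_cores_per_host for shard in range(10)]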
def _validate_tpu_configuration(self):
"""Validates the configuration based on the TPU system metadata."""
mode = self._assert_mode()
if self._lazy_validation_dict.get(mode):
return
# All following information is obtained from TPU system metadata.
num_cores = self.num_cores
num_replicas = self.num_replicas
num_hosts = self.num_hosts
if not num_cores:
tpu_system_metadata = self._get_tpu_system_metadata()
raise RuntimeError(
'Cannot find any TPU cores in the system. Please double check '
'Tensorflow master address and TPU worker(s). Available devices '
'are {}.'.format(tpu_system_metadata.devices))
if self._config.tpu_config.num_shards:
user_provided_num_replicas = self._config.tpu_config.num_shards
if user_provided_num_replicas != num_replicas:
message = (
'TPUConfig.num_shards is not set correctly. According to TPU '
'system metadata for Tensorflow master ({}): num_replicas should '
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total number of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`.'.format(
self._get_master_address(), num_replicas,
user_provided_num_replicas))
raise ValueError(message)
if self._config.tpu_config.num_cores_per_replica:
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
if num_cores_per_replica > num_cores_per_host:
raise ValueError(
'The number of cores required by model parallelism, specified by '
'TPUConfig.num_cores_per_replica, is larger than the '
'num_cores_per_host. num_cores_per_replica: {}, '
'num_cores_per_host: {}'.format(num_cores_per_replica,
num_cores_per_host))
if mode == model_fn_lib.ModeKeys.TRAIN:
if (self._train_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'train batch size {} must be divisible by number of replicas {}'
.format(self._train_batch_size, num_replicas))
elif mode == model_fn_lib.ModeKeys.EVAL:
if self._eval_batch_size is None:
raise ValueError(
'eval_batch_size in TPUEstimator constructor cannot be `None` '
'if .evaluate is running on TPU.')
if (self._eval_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'eval batch size {} must be divisible by number of replicas {}'
.format(self._eval_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.evaluate should be running on a single TPU'
' instead of a Pod.')
else:
assert mode == model_fn_lib.ModeKeys.PREDICT
if self._predict_batch_size is None:
raise ValueError(
'predict_batch_size in TPUEstimator constructor should not be '
'`None` if .predict is running on TPU.')
if (self._predict_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
raise ValueError(
'predict batch size {} must be divisible by number of replicas {}'
.format(self._predict_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
'TPUEstimator.predict should be running on a single TPU worker; '
'got {} hosts.'.format(num_hosts))
# Record the state "validated" in the lazy dictionary.
self._lazy_validation_dict[mode] = True
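# Illustrative sketch, not part of the original file (assumed values): the
# divisibility rule enforced above for TRAIN mode when input is not broadcast.
def _example_train_batch_divisibility(train_batch_size=1024, num_replicas=8):
  # True here; e.g. a train_batch_size of 1000 with 8 replicas would raise.
  return train_batch_size % num_replicas == 0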
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
This should be used for full replicate for non-model-parallelism.
Args:
replica_id: Int, the replica index.
Returns:
A tuple of device spec for CPU device and int device ordinal.
"""
master = self.master_job
if self.model_parallelism_enabled:
return (self.device_assignment.host_device(
replica=replica_id, job=master),
self.device_assignment.tpu_ordinal(replica=replica_id))
job_device = '' if master is None else ('/job:%s' % master)
num_of_replicas_per_host = self.num_of_replicas_per_host
host_id = replica_id // num_of_replicas_per_host
ordinal_id = replica_id % num_of_replicas_per_host
host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
return (host_device, ordinal_id)
class _OneCoreTPUContext(_InternalTPUContext):
"""Special _InternalTPUContext for one core usage."""
def __init__(self, config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu):
super(_OneCoreTPUContext, self).__init__(
config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
def _get_tpu_system_metadata(self):
"""Gets the (maybe cached) TPU system metadata."""
master = self._get_master_address()
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
if tpu_system_metadata is not None:
return tpu_system_metadata
tpu_system_metadata = (
tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access
num_cores=1,
num_hosts=1,
num_of_cores_per_host=1,
topology=None,
devices=[]))
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
return tpu_system_metadata
def _get_tpu_context(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu, eval_on_tpu,
embedding_config_spec):
"""Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
config.tpu_config.num_cores_per_replica is None):
if embedding_config_spec is not None:
raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
'when embedding_config_spec is not None.')
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
'Please fix as soon as possible (leave num_shards as None).')
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
return _InternalTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu, eval_on_tpu,
embedding_config_spec)
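# Illustrative sketch of the selection logic above, not part of the original
# file (returns class names only; the real factory returns instances and takes
# several more arguments).
def _example_context_selection(num_shards, num_cores_per_replica):
  if num_shards == 1 and num_cores_per_replica is None:
    return '_OneCoreTPUContext'
  return '_InternalTPUContext'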
# pylint: disable=wildcard-import,unused-import
from tensorflow_estimator.python.estimator.tpu.tpu_context import *
# pylint: enable=wildcard-import,unused-import

File diff suppressed because it is too large

View File

@ -1,339 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Estimator Signalling Tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_estimator
def make_input_fn(num_samples):
a = np.linspace(0, 100.0, num=num_samples)
b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
def input_fn(params):
batch_size = params['batch_size']
da1 = dataset_ops.Dataset.from_tensor_slices(a)
da2 = dataset_ops.Dataset.from_tensor_slices(b)
dataset = dataset_ops.Dataset.zip((da1, da2))
dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb})
dataset = dataset.batch(batch_size)
return dataset
return input_fn, (a, b)
def make_input_fn_with_labels(num_samples):
a = np.linspace(0, 100.0, num=num_samples)
b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
def input_fn(params):
batch_size = params['batch_size']
da1 = dataset_ops.Dataset.from_tensor_slices(a)
da2 = dataset_ops.Dataset.from_tensor_slices(b)
dataset = dataset_ops.Dataset.zip((da1, da2))
dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb))
dataset = dataset.batch(batch_size)
return dataset
return input_fn, (a, b)
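# Illustrative sketch, not part of the original file (plain numpy, assumed
# sizes): the first batch the input_fn helpers above are expected to yield,
# which the tests below verify via tf.data.
def _example_expected_first_batch(num_samples=4, batch_size=2):
  a = np.linspace(0, 100.0, num=num_samples)
  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
  return {'a': a[:batch_size], 'b': b[:batch_size]}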
class TPUEstimatorStoppingSignalsTest(test.TestCase):
def test_normal_output_without_signals(self):
num_samples = 4
batch_size = 2
params = {'batch_size': batch_size}
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
with ops.Graph().as_default():
dataset = input_fn(params)
features = dataset_ops.make_one_shot_iterator(dataset).get_next()
# With tf.data.Dataset.batch, the batch dimension is None, i.e., dynamic shape.
self.assertIsNone(features['a'].shape.as_list()[0])
with session.Session() as sess:
result = sess.run(features)
self.assertAllEqual(a[:batch_size], result['a'])
self.assertAllEqual(b[:batch_size], result['b'])
# This run should work as num_samples / batch_size = 2.
result = sess.run(features)
self.assertAllEqual(a[batch_size:num_samples], result['a'])
self.assertAllEqual(b[batch_size:num_samples], result['b'])
with self.assertRaises(errors.OutOfRangeError):
# Given num_samples and batch_size, this run should fail.
sess.run(features)
def test_output_with_stopping_signals(self):
num_samples = 4
batch_size = 2
params = {'batch_size': batch_size}
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
with ops.Graph().as_default():
dataset = input_fn(params)
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size)
dataset_initializer = inputs.dataset_initializer()
features, _ = inputs.features_and_labels()
signals = inputs.signals()
# With tf.data.Dataset.batch, the batch dimension is None, i.e., dynamic shape.
self.assertIsNone(features['a'].shape.as_list()[0])
with session.Session() as sess:
sess.run(dataset_initializer)
result, evaluated_signals = sess.run([features, signals])
self.assertAllEqual(a[:batch_size], result['a'])
self.assertAllEqual(b[:batch_size], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
# This run should work as num_samples / batch_size = 2.
result, evaluated_signals = sess.run([features, signals])
self.assertAllEqual(a[batch_size:num_samples], result['a'])
self.assertAllEqual(b[batch_size:num_samples], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
# This run should still work, *but* the signals now indicate STOP ('1').
_, evaluated_signals = sess.run([features, signals])
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
with self.assertRaises(errors.OutOfRangeError):
sess.run(features)
class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):
def test_num_samples_divisible_by_batch_size(self):
num_samples = 4
batch_size = 2
params = {'batch_size': batch_size}
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
with ops.Graph().as_default():
dataset = input_fn(params)
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
add_padding=True)
dataset_initializer = inputs.dataset_initializer()
features, _ = inputs.features_and_labels()
signals = inputs.signals()
# With padding, all shapes are static now.
self.assertEqual(batch_size, features['a'].shape.as_list()[0])
with session.Session() as sess:
sess.run(dataset_initializer)
result, evaluated_signals = sess.run([features, signals])
self.assertAllEqual(a[:batch_size], result['a'])
self.assertAllEqual(b[:batch_size], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
self.assertAllEqual([0.] * batch_size,
evaluated_signals['padding_mask'])
# This run should work as num_samples / batch_size = 2.
result, evaluated_signals = sess.run([features, signals])
self.assertAllEqual(a[batch_size:num_samples], result['a'])
self.assertAllEqual(b[batch_size:num_samples], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
self.assertAllEqual([0.] * batch_size,
evaluated_signals['padding_mask'])
# This run should still work, *but* the signals now indicate STOP ('1').
_, evaluated_signals = sess.run([features, signals])
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
with self.assertRaises(errors.OutOfRangeError):
sess.run(features)
def test_num_samples_not_divisible_by_batch_size(self):
num_samples = 5
batch_size = 2
params = {'batch_size': batch_size}
input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples)
with ops.Graph().as_default():
dataset = input_fn(params)
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
add_padding=True)
dataset_initializer = inputs.dataset_initializer()
features, labels = inputs.features_and_labels()
signals = inputs.signals()
# With padding, all shapes are static.
self.assertEqual(batch_size, features['a'].shape.as_list()[0])
with session.Session() as sess:
sess.run(dataset_initializer)
evaluated_features, evaluated_labels, evaluated_signals = (
sess.run([features, labels, signals]))
self.assertAllEqual(a[:batch_size], evaluated_features['a'])
self.assertAllEqual(b[:batch_size], evaluated_labels)
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
self.assertAllEqual([0.] * batch_size,
evaluated_signals['padding_mask'])
# This run should work as num_samples / batch_size >= 2.
evaluated_features, evaluated_labels, evaluated_signals = (
sess.run([features, labels, signals]))
self.assertAllEqual(a[batch_size:2*batch_size], evaluated_features['a'])
self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels)
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
self.assertAllEqual([0.] * batch_size,
evaluated_signals['padding_mask'])
# This is the final partial batch.
evaluated_features, evaluated_labels, evaluated_signals = (
sess.run([features, labels, signals]))
real_batch_size = num_samples % batch_size
# Assert the real part.
self.assertAllEqual(a[2*batch_size:num_samples],
evaluated_features['a'][:real_batch_size])
self.assertAllEqual(b[2*batch_size:num_samples],
evaluated_labels[:real_batch_size])
# Assert the padded part.
self.assertAllEqual([0.0] * (batch_size - real_batch_size),
evaluated_features['a'][real_batch_size:])
self.assertAllEqual([[0.0]] * (batch_size - real_batch_size),
evaluated_labels[real_batch_size:])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
padding = ([.0] * real_batch_size
+ [1.] * (batch_size - real_batch_size))
self.assertAllEqual(padding, evaluated_signals['padding_mask'])
# This run should still work, *but* the signals now indicate STOP ('1').
_, evaluated_signals = sess.run([features, signals])
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
with self.assertRaises(errors.OutOfRangeError):
sess.run(features)
def test_slice(self):
num_samples = 3
batch_size = 2
params = {'batch_size': batch_size}
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
with ops.Graph().as_default():
dataset = input_fn(params)
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
add_padding=True)
dataset_initializer = inputs.dataset_initializer()
features, _ = inputs.features_and_labels()
signals = inputs.signals()
sliced_features = (
tpu_estimator._PaddingSignals.slice_tensor_or_dict(
features, signals))
with session.Session() as sess:
sess.run(dataset_initializer)
result, evaluated_signals = sess.run([sliced_features, signals])
self.assertAllEqual(a[:batch_size], result['a'])
self.assertAllEqual(b[:batch_size], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
# This is the final partial batch.
result, evaluated_signals = sess.run([sliced_features, signals])
self.assertEqual(1, len(result['a']))
self.assertAllEqual(a[batch_size:num_samples], result['a'])
self.assertAllEqual(b[batch_size:num_samples], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
# This run should still work, *but* the signals now indicate STOP ('1').
_, evaluated_signals = sess.run([sliced_features, signals])
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
with self.assertRaises(errors.OutOfRangeError):
sess.run(sliced_features)
def test_slice_with_multi_invocations_per_step(self):
num_samples = 3
batch_size = 2
params = {'batch_size': batch_size}
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
with ops.Graph().as_default():
dataset = input_fn(params)
inputs = tpu_estimator._InputsWithStoppingSignals(
dataset, batch_size, add_padding=True, num_invocations_per_step=2)
dataset_initializer = inputs.dataset_initializer()
features, _ = inputs.features_and_labels()
signals = inputs.signals()
sliced_features = (
tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))
with session.Session() as sess:
sess.run(dataset_initializer)
result, evaluated_signals = sess.run([sliced_features, signals])
self.assertAllEqual(a[:batch_size], result['a'])
self.assertAllEqual(b[:batch_size], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
# This is the final partial batch.
result, evaluated_signals = sess.run([sliced_features, signals])
self.assertEqual(1, len(result['a']))
self.assertAllEqual(a[batch_size:num_samples], result['a'])
self.assertAllEqual(b[batch_size:num_samples], result['b'])
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
# We should see three consecutive batches with STOP ('1') as signals, all
# of them with padding mask 1.
_, evaluated_signals = sess.run([sliced_features, signals])
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
self.assertAllEqual([1.] * batch_size,
evaluated_signals['padding_mask'])
_, evaluated_signals = sess.run([sliced_features, signals])
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
self.assertAllEqual([1.] * batch_size,
evaluated_signals['padding_mask'])
_, evaluated_signals = sess.run([sliced_features, signals])
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
self.assertAllEqual([1.] * batch_size,
evaluated_signals['padding_mask'])
with self.assertRaises(errors.OutOfRangeError):
sess.run(sliced_features)
if __name__ == '__main__':
test.main()

View File

@ -1,51 +1,23 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Utilities for the functionalities."""
# ==============================================================================
"""Stub file to maintain backwards compatibility."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import six
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
def check_positive_integer(value, name):
"""Checks whether `value` is a positive integer."""
if not isinstance(value, six.integer_types):
raise TypeError('{} must be int, got {}'.format(name, type(value)))
if value <= 0:
raise ValueError('{} must be positive, got {}'.format(name, value))
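# Illustrative usage sketch for the check above, not part of the original file
# (the argument name 'iterations_per_loop' is only a hypothetical label).
def _example_check_positive_integer_usage():
  check_positive_integer(8, 'iterations_per_loop')  # passes silently
  try:
    check_positive_integer(0, 'iterations_per_loop')
  except ValueError:
    pass  # raised: value must be positive
  try:
    check_positive_integer(2.5, 'iterations_per_loop')
  except TypeError:
    pass  # raised: value must be an int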
# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we
# release a tensorflow_estimator with MultiHostDatasetInitializerHook in
# python/estimator/util.py.
class MultiHostDatasetInitializerHook(training.SessionRunHook):
"""Creates a SessionRunHook that initializes all passed iterators."""
def __init__(self, dataset_initializers):
self._initializers = dataset_initializers
def after_create_session(self, session, coord):
del coord
start = time.time()
session.run(self._initializers)
logging.info('Initialized dataset iterators in %d seconds',
time.time() - start)
# pylint: disable=wildcard-import,unused-import
from tensorflow_estimator.python.estimator.tpu.util import *
# pylint: enable=wildcard-import,unused-import