Delete the TF core copy of TPUEstimator in favor of the Estimator repo's version.

PiperOrigin-RevId: 247493942

parent c182167d3f
commit 72d36073e9
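
Every file touched below follows the same pattern: the concrete implementation is deleted from TF core's tpu package and replaced by a thin stub that re-exports the canonical copy now living in the tensorflow_estimator repository, so existing import paths keep resolving. A minimal sketch of the resulting stub and of caller code that is unaffected (illustrative only; the hypothetical user snippet is not part of this commit, and the real stubs appear in the hunks below):

# Shape of the backwards-compatibility stub left behind in TF core
# (see the tpu_config hunk below for the actual file).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import,unused-import
from tensorflow_estimator.python.estimator.tpu.tpu_config import *
# pylint: enable=wildcard-import,unused-import

# Caller code written against the old location keeps working unchanged,
# e.g. (hypothetical user snippet, assuming the old import path):
#   from tensorflow.python.tpu import tpu_config
#   config = tpu_config.RunConfig(
#       tpu_config=tpu_config.TPUConfig(num_shards=8))
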
@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import,unused-import
-from tensorflow.python.tpu._tpu_estimator_embedding import *
+from tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import *
 # pylint: enable=wildcard-import,unused-import

@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import,unused-import
-from tensorflow.python.tpu.error_handling import *
+from tensorflow_estimator.python.estimator.tpu.error_handling import *
 # pylint: enable=wildcard-import,unused-import

@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import,unused-import
-from tensorflow.python.tpu.tpu_config import *
+from tensorflow_estimator.python.estimator.tpu.tpu_config import *
 # pylint: enable=wildcard-import,unused-import

@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import,unused-import
-from tensorflow.python.tpu.tpu_context import *
+from tensorflow_estimator.python.estimator.tpu.tpu_context import *
 # pylint: enable=wildcard-import,unused-import

@@ -19,15 +19,15 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import,unused-import,redefined-builtin
-from tensorflow.python.tpu.tpu_estimator import *
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import *
 # used by tests
-from tensorflow.python.tpu.tpu_estimator import _clone_export_output_with_tensors
-from tensorflow.python.tpu.tpu_estimator import _create_global_step
-from tensorflow.python.tpu.tpu_estimator import _export_output_to_tensors
-from tensorflow.python.tpu.tpu_estimator import _get_scaffold
-from tensorflow.python.tpu.tpu_estimator import _Inputs
-from tensorflow.python.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR
-from tensorflow.python.tpu.tpu_estimator import _TPU_ENQUEUE_OPS
-from tensorflow.python.tpu.tpu_estimator import _TPU_ESTIMATOR
-from tensorflow.python.tpu.tpu_estimator import _TPU_TRAIN_OP
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _clone_export_output_with_tensors
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _create_global_step
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _export_output_to_tensors
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _get_scaffold
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _Inputs
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _ITERATIONS_PER_LOOP_VAR
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ENQUEUE_OPS
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_ESTIMATOR
+from tensorflow_estimator.python.estimator.tpu.tpu_estimator import _TPU_TRAIN_OP
 # pylint: enable=wildcard-import,unused-import,redefined-builtin

@@ -19,5 +19,5 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=wildcard-import,unused-import
-from tensorflow.python.tpu.util import *
+from tensorflow_estimator.python.estimator.tpu.util import *
 # pylint: enable=wildcard-import,unused-import

@@ -275,30 +275,6 @@ tf_py_test(
     ],
 )
 
-tf_py_test(
-    name = "tpu_config_test",
-    size = "small",
-    srcs = ["tpu_config_test.py"],
-    additional_deps = [
-        ":tpu_estimator",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:framework_test_lib",
-    ],
-)
-
-tf_py_test(
-    name = "tpu_estimator_signals_test",
-    size = "small",
-    srcs = ["tpu_estimator_signals_test.py"],
-    additional_deps = [
-        ":tpu_estimator",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:framework_test_lib",
-    ],
-    # TODO(jhseu): Remove. Fails in OSS on Python 3.
-    tags = ["no_oss"],
-)
-
 tf_py_test(
     name = "topology_test",
     size = "medium",
@ -1,382 +1,23 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""Tooling for support TPU embedding in TPUEstimator."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.feature_column import feature_column as core_fc
|
||||
from tensorflow.python.feature_column import feature_column_lib as core_fc_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.tpu import feature_column as tpu_fc
|
||||
from tensorflow.python.tpu import tpu_embedding
|
||||
from tensorflow.python.tpu.tpu_embedding import AdagradParameters
|
||||
from tensorflow.python.tpu.tpu_embedding import AdamParameters
|
||||
from tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters
|
||||
from tensorflow.python.training import training
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,
|
||||
tpu_fc._TPUSharedEmbeddingColumn)
|
||||
_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,
|
||||
core_fc_lib.EmbeddingColumn,
|
||||
core_fc._SharedEmbeddingColumn)
|
||||
_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)
|
||||
_SUPPORTED_OPTIMIZERS = (AdagradParameters, AdamParameters,
|
||||
StochasticGradientDescentParameters)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
_TABLE_NAME_PREFIX = 'tbl_'
|
||||
_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)
|
||||
|
||||
|
||||
def _get_table_name_from_embedding_var_name(embedding_var_name):
|
||||
return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)
|
||||
|
||||
|
||||
def _get_embedding_var_name_from_table_name(table_name):
|
||||
return table_name[_LEN_TABLE_NAME_PREFIX:]
|
||||
|
||||
|
||||
def _get_embedding_variable_name(scope_name, var_name):
|
||||
return '{}/{}'.format(scope_name, var_name)
|
||||
|
||||
|
||||
def _get_slot_variable_names(scope_name, var_name, optimization_parameters):
|
||||
"""Return embedding variable names which are consistent with CPU runs."""
|
||||
if isinstance(optimization_parameters, tpu_embedding.AdagradParameters):
|
||||
return tpu_embedding.AdagradSlotVariableName(
|
||||
'{}/{}/Adagrad'.format(scope_name, var_name)
|
||||
)
|
||||
elif isinstance(optimization_parameters, tpu_embedding.AdamParameters):
|
||||
return tpu_embedding.AdamSlotVariableNames(
|
||||
'{}/{}/Adam/m'.format(scope_name, var_name),
|
||||
'{}/{}/Adam/v'.format(scope_name, var_name)
|
||||
)
|
||||
elif isinstance(optimization_parameters,
|
||||
tpu_embedding.StochasticGradientDescentParameters):
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Support to infer full variable name '
|
||||
'for optimization_parameter {} has not been added.'
|
||||
.format(optimization_parameters))
|
||||
|
||||
|
||||
def get_full_variable_names(
|
||||
graph, table_to_config_dict, optimization_parameters=None):
|
||||
"""Return embedding variable names and slot variables which are consistent with CPU runs."""
|
||||
collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
|
||||
if not collection:
|
||||
raise RuntimeError(
|
||||
'Embedding feature column did not capture any thing. Make sure the '
|
||||
'feature columns passed to TPUEstimator constructor is properly '
|
||||
'used in model_fn.')
|
||||
|
||||
embedding_variable_name_by_table = {}
|
||||
slot_variable_names_by_table = {}
|
||||
for table_name in table_to_config_dict:
|
||||
embedding_var_name = _get_embedding_var_name_from_table_name(table_name)
|
||||
(scope_name, var_name) = collection[0][embedding_var_name]
|
||||
embedding_variable_name_by_table[table_name] = (
|
||||
_get_embedding_variable_name(scope_name, var_name))
|
||||
if optimization_parameters:
|
||||
slot_variable_names_by_table[table_name] = _get_slot_variable_names(
|
||||
scope_name, var_name, optimization_parameters)
|
||||
|
||||
graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE) # pylint: disable=protected-access
|
||||
return embedding_variable_name_by_table, slot_variable_names_by_table
|
||||
|
||||
|
||||
def get_configs_from_feature_columns(feature_columns):
|
||||
"""Create configs for TPUEmbedding etc from a list of feature columns.
|
||||
|
||||
Args:
|
||||
feature_columns: a list of supported feature columns.
|
||||
|
||||
Returns:
|
||||
A tuple of dicts, the first maps tables to their config, the second maps
|
||||
features to their config, and the third maps features to weight key names.
|
||||
"""
|
||||
|
||||
allowed = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn) # pylint: disable=protected-access
|
||||
|
||||
for column in feature_columns:
|
||||
if not isinstance(column, allowed):
|
||||
raise TypeError(
|
||||
'Unsupported feature column {}. Supported types are {}.'.format(
|
||||
type(column), allowed))
|
||||
|
||||
table_to_config = {}
|
||||
feature_to_config = {}
|
||||
feature_to_weight_key_name = {}
|
||||
for column in feature_columns:
|
||||
feature_name = column.get_feature_key_name()
|
||||
table_name = _get_table_name_from_embedding_var_name(
|
||||
column.get_embedding_var_name())
|
||||
if feature_name in feature_to_config:
|
||||
raise ValueError(
|
||||
'Feature column {} is used with multiple embeddings and this is '
|
||||
'not supported.'.format(feature_name))
|
||||
feature_to_config[feature_name] = tpu_embedding.FeatureConfig(
|
||||
table_id=table_name)
|
||||
feature_to_weight_key_name[feature_name] = column.get_weight_key_name()
|
||||
vocabulary_size, dimension = column.get_embedding_table_size()
|
||||
table_to_config[table_name] = tpu_embedding.TableConfig(
|
||||
vocabulary_size=vocabulary_size,
|
||||
dimension=dimension,
|
||||
initializer=column.get_initializer(),
|
||||
combiner=column.get_combiner())
|
||||
|
||||
return table_to_config, feature_to_config, feature_to_weight_key_name
|
||||
|
||||
|
||||
class EmbeddingConfigSpec(
|
||||
collections.namedtuple('EmbeddingConfigSpec', [
|
||||
'feature_columns', 'optimization_parameters', 'clipping_limit',
|
||||
'pipeline_execution_with_tensor_core',
|
||||
'experimental_gradient_multiplier_fn'
|
||||
])):
|
||||
"""Class to keep track of embedding config specification."""
|
||||
|
||||
def __new__(cls,
|
||||
feature_columns,
|
||||
optimization_parameters,
|
||||
clipping_limit=None,
|
||||
pipeline_execution_with_tensor_core=False,
|
||||
experimental_gradient_multiplier_fn=None):
|
||||
"""Creates an EmbeddingConfigSpec instance.
|
||||
|
||||
Args:
|
||||
feature_columns: All `FeatureColumn`s used by model.
|
||||
optimization_parameters: An instance of `AdagradParameters`,
|
||||
`AdamParameters` or `StochasticGradientDescentParameters`. This
|
||||
optimizer will be applied to all embedding variables specified by
|
||||
`feature_columns`.
|
||||
clipping_limit: (Optional) Clipping limit (absolute value).
|
||||
pipeline_execution_with_tensor_core: setting this to `True` makes training
|
||||
faster, but trained model will be different if step N and step N+1
|
||||
involve the same set of embedding IDs. Please see
|
||||
`tpu_embedding_configuration.proto` for details.
|
||||
experimental_gradient_multiplier_fn: (Optional) A Fn taking global step as
|
||||
input returning the current multiplier for all embedding gradients.
|
||||
|
||||
Returns:
|
||||
An EmbeddingConfigSpec instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the feature_columns are not specified.
|
||||
TypeError: If the feature columns are not of ths correct type (one of
|
||||
_SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES OR
|
||||
_EMBEDDING_COLUMN_CLASSES).
|
||||
ValueError: If `optimization_parameters` is not one of the required types.
|
||||
"""
|
||||
if not feature_columns:
|
||||
raise ValueError('`feature_columns` cannot be `None` or empty.')
|
||||
|
||||
# It is unknown at this moment, whether the TPUEstimator is running in CPU
|
||||
# or TPU mode. So allow non-TPU embedding columns also.
|
||||
supported_classes = tuple(
|
||||
list(_SUPPORTED_FEATURE_COLUMNS) + list(_TPU_EMBEDDING_COLUMN_CLASSES) +
|
||||
list(_EMBEDDING_COLUMN_CLASSES))
|
||||
|
||||
for column in feature_columns:
|
||||
if not isinstance(column, supported_classes):
|
||||
raise TypeError(
|
||||
'All feature columns must be supported types in {}. Got {}'.format(
|
||||
supported_classes, type(column)))
|
||||
|
||||
if not isinstance(optimization_parameters, _SUPPORTED_OPTIMIZERS):
|
||||
raise ValueError('optimization_parameters must be an instance of type '
|
||||
'{}. Got {}.'.format(_SUPPORTED_OPTIMIZERS,
|
||||
type(optimization_parameters)))
|
||||
|
||||
return super(EmbeddingConfigSpec, cls).__new__(
|
||||
cls,
|
||||
feature_columns=feature_columns,
|
||||
optimization_parameters=optimization_parameters,
|
||||
clipping_limit=clipping_limit,
|
||||
pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core,
|
||||
experimental_gradient_multiplier_fn=experimental_gradient_multiplier_fn)
|
||||
|
||||
|
||||
class EmbeddingConfig(object):
|
||||
"""This is the internal immutable object for embedding config.
|
||||
|
||||
`_EmbeddingConfig` is responsible to _translate_ user provided
|
||||
`EmbeddingConfigSpec` to internal data structures, mostly constructor
|
||||
arguments of `TPUEmbedding`.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
|
||||
num_hosts, num_cores, run_config):
|
||||
if not embedding_config_spec:
|
||||
raise ValueError('embedding_config_spec cannot be None.')
|
||||
|
||||
self._embedding_config_spec = embedding_config_spec
|
||||
self._train_batch_size = train_batch_size
|
||||
self._eval_batch_size = eval_batch_size
|
||||
self._num_hosts = num_hosts
|
||||
self._num_cores = num_cores
|
||||
self._run_config = run_config
|
||||
|
||||
(self._table_to_config_dict, self._feature_to_config_dict,
|
||||
self.feature_to_weight_key_name_dict) = (
|
||||
get_configs_from_feature_columns(
|
||||
embedding_config_spec.feature_columns))
|
||||
self._mode_to_tpu_embedding_dict = {}
|
||||
self.dummy_table_variables = None
|
||||
|
||||
self._grad_multiplier_fn = (
|
||||
embedding_config_spec.experimental_gradient_multiplier_fn)
|
||||
|
||||
def get_grad_multiplier(self):
|
||||
if self._grad_multiplier_fn:
|
||||
return ops.convert_to_tensor(
|
||||
self._grad_multiplier_fn(training.get_global_step()),
|
||||
dtype=dtypes.float32)
|
||||
|
||||
def has_embedding_tables(self):
|
||||
return bool(self._table_to_config_dict)
|
||||
|
||||
def _create_tpu_embedding(self, mode):
|
||||
"""Create tpu_embedding.TPUEmbedding based on mode."""
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
batch_size = self._train_batch_size
|
||||
else:
|
||||
batch_size = self._eval_batch_size
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
tpu_embedding_mode = tpu_embedding.TRAINING
|
||||
optimization_parameters = (
|
||||
self._embedding_config_spec.optimization_parameters)
|
||||
elif (mode == model_fn_lib.ModeKeys.EVAL or
|
||||
mode == model_fn_lib.ModeKeys.PREDICT):
|
||||
tpu_embedding_mode = tpu_embedding.INFERENCE
|
||||
optimization_parameters = None
|
||||
else:
|
||||
raise ValueError('Mode {} is not supported.'.format(mode))
|
||||
|
||||
if self._run_config.cluster:
|
||||
master = self._run_config.cluster.master()
|
||||
cluster_spec = self._run_config.cluster.cluster_spec()
|
||||
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
|
||||
else:
|
||||
master = (
|
||||
self._run_config.evaluation_master
|
||||
if mode == model_fn_lib.ModeKeys.EVAL else self._run_config.master)
|
||||
cluster_def = None
|
||||
tpu_embedding_ = tpu_embedding.TPUEmbedding(
|
||||
self._table_to_config_dict,
|
||||
self._feature_to_config_dict,
|
||||
batch_size,
|
||||
tpu_embedding_mode,
|
||||
master,
|
||||
optimization_parameters,
|
||||
cluster_def,
|
||||
pipeline_execution_with_tensor_core=self._embedding_config_spec
|
||||
.pipeline_execution_with_tensor_core)
|
||||
return tpu_embedding_
|
||||
|
||||
def get_tpu_embedding(self, mode):
|
||||
if mode not in self._mode_to_tpu_embedding_dict:
|
||||
self._mode_to_tpu_embedding_dict[mode] = (
|
||||
self._create_tpu_embedding(mode))
|
||||
return self._mode_to_tpu_embedding_dict[mode]
|
||||
|
||||
|
||||
def split_inputs(ctx, features, labels):
|
||||
"""Splits the dense and sparse tensors inside the features and labels."""
|
||||
enqueue_datas = collections.OrderedDict()
|
||||
if ctx.embedding_config:
|
||||
tpu_embedding_ = ctx.embedding_config.tpu_embedding
|
||||
feature_to_weight_key_name_dict = (
|
||||
ctx.embedding_config.feature_to_weight_key_name_dict)
|
||||
for feature_key in tpu_embedding_.feature_to_config_dict:
|
||||
sparse_feature = _get_sparse_feature_from_feature(feature_key, features)
|
||||
weight_key_name = feature_to_weight_key_name_dict[feature_key]
|
||||
if isinstance(sparse_feature, sparse_tensor.SparseTensor):
|
||||
weights = _get_weights_from_features(weight_key_name, features)
|
||||
enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(
|
||||
sparse_feature, weights)
|
||||
else:
|
||||
if weight_key_name is not None:
|
||||
raise ValueError(
|
||||
'Found weights {} for weighted_categorical_column, which is not'
|
||||
'compatible with sparse feature {} enqueued as dense tensor.'
|
||||
.format(weight_key_name, feature_key))
|
||||
enqueue_data = tpu_embedding.EnqueueData(sparse_feature)
|
||||
enqueue_datas[feature_key] = enqueue_data
|
||||
|
||||
return features, labels, enqueue_datas
|
||||
|
||||
|
||||
def _get_sparse_feature_from_feature(feature_key, features):
|
||||
"""Pop and return sparse feature."""
|
||||
sparse_feature = features.pop(feature_key)
|
||||
if not sparse_feature.dtype.is_integer:
|
||||
raise ValueError('SparseTensor with string as values are not supported. '
|
||||
'If you are using vocabulary_file_categorical_column or '
|
||||
'vocabulary_list_categorical_column, please call '
|
||||
'your_column.categorical_column._transform_feature({{'
|
||||
'your_column.key: features[your_column.key]}}) in'
|
||||
'your input_fn() to convert string to int. '
|
||||
'feature_key = {}.'.format(feature_key))
|
||||
return sparse_feature
|
||||
|
||||
|
||||
def _get_weights_from_features(weight_key_name, features):
|
||||
"""Pop and return feature for weights, possibly None."""
|
||||
weights = None
|
||||
if weight_key_name is not None:
|
||||
if weight_key_name in features:
|
||||
weights = features.pop(weight_key_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
'Cannot find weights {} for weighted_categorical_column.'
|
||||
' Please check if the weights are present in feature dict. Also'
|
||||
' note weight-sharing among weighted_categorical_column is not '
|
||||
'supported on TPU.'.format(weight_key_name))
|
||||
if not isinstance(weights, sparse_tensor.SparseTensor):
|
||||
raise ValueError(
|
||||
'weighted_categorical_column with weight key name {} has dense '
|
||||
'weights. Dense weights are not supported on TPU. Please use '
|
||||
'sparse weights instead.'.format(weight_key_name))
|
||||
if weights.dtype is not dtypes.float32:
|
||||
weights = math_ops.to_float(weights)
|
||||
return weights
|
||||
|
||||
|
||||
def get_tpu_embedding_columns(feature_columns):
|
||||
"""Get feature columns meant to use TPU embedding.
|
||||
|
||||
Args:
|
||||
feature_columns: a list of feature columns.
|
||||
|
||||
Returns:
|
||||
A list of feature columns which can be placed on TPU embedding.
|
||||
"""
|
||||
tpu_embedding_columns = []
|
||||
for column in feature_columns:
|
||||
if isinstance(column, _TPU_EMBEDDING_COLUMN_CLASSES):
|
||||
tpu_embedding_columns.append(column)
|
||||
return tpu_embedding_columns
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@ -1,135 +1,23 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""ErrorRendezvous handler for collecting errors from multiple threads."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
_UNINTERESTING_ERRORS = (errors.CancelledError,)
|
||||
|
||||
|
||||
class ErrorRendezvous(object):
|
||||
"""Resolve errors from multiple threads during TPU execution.
|
||||
|
||||
TPU errors can occur on the infeed or outfeed threads as well as the main
|
||||
training thread.
|
||||
|
||||
Depending on which thread "wins" and receives the session error first, we may
|
||||
end up showing users a confusing and non-actionable error message (session
|
||||
cancelled) instead of a root cause (e.g. a bad filename).
|
||||
|
||||
The rendezvous object provides a location to capture these errors until all
|
||||
threads terminate. At that point we can choose the most informative error
|
||||
to report.
|
||||
"""
|
||||
|
||||
def __init__(self, num_sources):
|
||||
# string -> (message, traceback)
|
||||
self._errors = {}
|
||||
self._num_sources = num_sources
|
||||
self._session_cancel_timer = None
|
||||
|
||||
def record_error(self, source, exc_info, session=None):
|
||||
"""Report an exception from the given source.
|
||||
|
||||
If a session is passed, a timer will be registered to close it after a few
|
||||
seconds. This is necessary to ensure the main training loop does not hang
|
||||
if an infeed/oufeed error occurs. We sleep a few seconds to allow a more
|
||||
interesting error from another thread to propagate.
|
||||
|
||||
Args:
|
||||
source: string, source of the error
|
||||
exc_info: Output from `sys.exc_info` (type, value, traceback)
|
||||
session: Session to close after delay.
|
||||
"""
|
||||
_, value, _ = exc_info
|
||||
self._errors[source] = exc_info
|
||||
logging.error('Error recorded from %s: %s', source, value)
|
||||
|
||||
if session is not None and self._session_cancel_timer is None:
|
||||
|
||||
def _cancel_session():
|
||||
time.sleep(5)
|
||||
logging.error('Closing session due to error %s' % value)
|
||||
try:
|
||||
session.close()
|
||||
except: # pylint: disable=bare-except
|
||||
logging.error(
|
||||
'\n\n\nFailed to close session after error.'
|
||||
'Other threads may hang.\n\n\n')
|
||||
|
||||
self._session_cancel_timer = threading.Thread(target=_cancel_session,)
|
||||
self._session_cancel_timer.daemon = True
|
||||
self._session_cancel_timer.start()
|
||||
|
||||
def record_done(self, source):
|
||||
"""Mark execution source `source` as done.
|
||||
|
||||
If an error was originally reported from `source` it is left intact.
|
||||
|
||||
Args:
|
||||
source: `str`, source being recorded
|
||||
"""
|
||||
logging.info('%s marked as finished', source)
|
||||
if source not in self._errors:
|
||||
self._errors[source] = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def catch_errors(self, source, session=None):
|
||||
"""Context manager to report any errors within a block."""
|
||||
try:
|
||||
yield
|
||||
except Exception: # pylint: disable=broad-except
|
||||
self.record_error(source, sys.exc_info(), session)
|
||||
|
||||
def raise_errors(self, timeout_sec=0):
|
||||
"""Wait for up to `timeout` seconds for all error sources to finish.
|
||||
|
||||
Preferentially raise "interesting" errors (errors not in the
|
||||
_UNINTERESTING_ERRORS) set.
|
||||
|
||||
Args:
|
||||
timeout_sec: Seconds to wait for other error sources.
|
||||
"""
|
||||
for _ in range(timeout_sec):
|
||||
if len(self._errors) == self._num_sources:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None]
|
||||
|
||||
# First check for any interesting errors, then fall back on the session
|
||||
# cancelled errors etc.
|
||||
for k, (typ, value, traceback) in kept_errors:
|
||||
if isinstance(value, _UNINTERESTING_ERRORS):
|
||||
continue
|
||||
else:
|
||||
logging.warn('Reraising captured error')
|
||||
six.reraise(typ, value, traceback)
|
||||
|
||||
for k, (typ, value, traceback) in kept_errors:
|
||||
logging.warn('Reraising captured error')
|
||||
six.reraise(typ, value, traceback)
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow_estimator.python.estimator.tpu.error_handling import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@ -1,295 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""A RunConfig subclass with TPU support."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.estimator import run_config as run_config_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import util as util_lib
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
|
||||
_SERVICE_KEY = run_config_lib._SERVICE_KEY
|
||||
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class InputPipelineConfig(object):
|
||||
r"""Please see the definition of these values in TPUConfig."""
|
||||
PER_SHARD_V1 = 1
|
||||
PER_HOST_V1 = 2
|
||||
PER_HOST_V2 = 3
|
||||
BROADCAST = 4
|
||||
SLICED = 5
|
||||
|
||||
|
||||
class TPUConfig(
|
||||
collections.namedtuple('TPUConfig', [
|
||||
'iterations_per_loop',
|
||||
'num_shards',
|
||||
'num_cores_per_replica',
|
||||
'per_host_input_for_training',
|
||||
'tpu_job_name',
|
||||
'initial_infeed_sleep_secs',
|
||||
'input_partition_dims',
|
||||
'eval_training_input_configuration',
|
||||
])):
|
||||
r"""TPU related configuration required by `TPUEstimator`.
|
||||
|
||||
Args:
|
||||
iterations_per_loop: This is the number of train steps running in TPU
|
||||
system before returning to CPU host for each `Session.run`. This means
|
||||
global step is increased `iterations_per_loop` times in one `Session.run`.
|
||||
It is recommended to be set as number of global steps for next checkpoint.
|
||||
Note that in evaluation don't use this value, instead we run total eval
|
||||
`steps` on TPU for a single `Session.run`.
|
||||
num_shards: (Deprecated, ignored by TPUEstimator).
|
||||
The number of model replicas in the system. For non-model-parallelism
|
||||
case, this number equals the total number of TPU cores. For
|
||||
model-parallelism, the total number of TPU cores equals
|
||||
num_cores_per_replica * num_shards.
|
||||
num_cores_per_replica: Defaults to `None`, which disables model parallelism.
|
||||
An integer which describes the number of TPU cores per model replica. This
|
||||
is required by model-parallelism which enables partitioning
|
||||
the model to multiple cores. Currently num_cores_per_replica must be
|
||||
1, 2, 4, or 8.
|
||||
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
|
||||
`input_fn` is invoked once on each host. With the per-core input pipeline
|
||||
configuration, it is invoked once for each core.
|
||||
With a global batch size `train_batch_size` in `TPUEstimator` constructor,
|
||||
the batch size for each shard is `train_batch_size` // #hosts in the
|
||||
`True` or `PER_HOST_V1` mode. In `PER_HOST_V2` mode, it is
|
||||
`train_batch_size` // #cores. In `BROADCAST` mode, `input_fn` is only
|
||||
invoked once on host 0 and the tensors are broadcasted to all other
|
||||
replicas. The batch size equals to `train_batch_size`. With the per-core
|
||||
input pipeline configuration, the shard batch size is also
|
||||
`train_batch_size` // #cores.
|
||||
Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.
|
||||
tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
|
||||
within TPUEstimator, however when using ClusterSpec propagation in more
|
||||
esoteric cluster configurations, you may need to specify the job name as a
|
||||
string.
|
||||
initial_infeed_sleep_secs: The number of seconds the infeed thread should
|
||||
wait before enqueueing the first batch. This helps avoid timeouts for
|
||||
models that require a long compilation time.
|
||||
input_partition_dims: A nested list to describe the partition dims
|
||||
for all the tensors from input_fn(). The structure of
|
||||
input_partition_dims must match the structure of `features` and
|
||||
`labels` from input_fn(). The total number of partitions must match
|
||||
`num_cores_per_replica`. For example, if input_fn() returns two tensors:
|
||||
images with shape [N, H, W, C] and labels [N].
|
||||
input_partition_dims = [[1, 2, 2, 1], None] will split the images to 4
|
||||
pieces and feed into 4 TPU cores. labels tensor are directly broadcasted
|
||||
to all the TPU cores since the partition dims is `None`.
|
||||
Current limitations: This feature is only supported with the PER_HOST_V2
|
||||
input mode.
|
||||
eval_training_input_configuration: If `SLICED`, `input_fn` is only
|
||||
invoked once on host 0 and the tensors are broadcasted to all other
|
||||
replicas. Unlike per_host_input_for_training=BROADCAST, each replica will
|
||||
only get a slice of the data instead of a whole copy. If `PER_HOST_V1`,
|
||||
the behaviour is determined by per_host_input_for_training.
|
||||
|
||||
Raises:
|
||||
ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
|
||||
"""
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
iterations_per_loop=2,
|
||||
num_shards=None,
|
||||
num_cores_per_replica=None,
|
||||
per_host_input_for_training=True,
|
||||
tpu_job_name=None,
|
||||
initial_infeed_sleep_secs=None,
|
||||
input_partition_dims=None,
|
||||
eval_training_input_configuration=InputPipelineConfig.PER_HOST_V1):
|
||||
|
||||
# Check iterations_per_loop.
|
||||
util_lib.check_positive_integer(iterations_per_loop,
|
||||
'TPUConfig iterations_per_loop')
|
||||
|
||||
# Check num_shards.
|
||||
if num_shards is not None:
|
||||
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
|
||||
|
||||
if input_partition_dims is not None:
|
||||
if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:
|
||||
raise ValueError(
|
||||
'input_partition_dims must be a list/tuple with one or two'
|
||||
' elements.')
|
||||
|
||||
if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:
|
||||
raise ValueError(
|
||||
'input_partition_dims is only supported in PER_HOST_V2 mode.')
|
||||
|
||||
if num_cores_per_replica is None:
|
||||
raise ValueError(
|
||||
'input_partition_dims requires setting num_cores_per_replica.')
|
||||
|
||||
# Check num_cores_per_replica
|
||||
if num_cores_per_replica is not None:
|
||||
if num_cores_per_replica not in [1, 2, 4, 8, 16]:
|
||||
raise ValueError(
|
||||
'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
|
||||
str(num_cores_per_replica)))
|
||||
|
||||
if eval_training_input_configuration not in [
|
||||
InputPipelineConfig.PER_HOST_V1, InputPipelineConfig.SLICED
|
||||
]:
|
||||
raise ValueError(
|
||||
'eval_training_input_configuration must be PER_HOST_V1 or SLICED;'
|
||||
' got {}'.format(str(eval_training_input_configuration)))
|
||||
|
||||
# per_host_input_for_training may be True, False, or integer in [1..3].
|
||||
# Map legacy values (True, False) to numeric values.
|
||||
if per_host_input_for_training is False:
|
||||
per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1
|
||||
elif per_host_input_for_training is True:
|
||||
per_host_input_for_training = InputPipelineConfig.PER_HOST_V1
|
||||
|
||||
# Check initial_infeed_sleep_secs.
|
||||
if initial_infeed_sleep_secs:
|
||||
util_lib.check_positive_integer(initial_infeed_sleep_secs,
|
||||
'TPUConfig initial_infeed_sleep_secs')
|
||||
|
||||
tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()
|
||||
|
||||
return super(TPUConfig, cls).__new__(
|
||||
cls,
|
||||
iterations_per_loop=iterations_per_loop,
|
||||
num_shards=num_shards,
|
||||
num_cores_per_replica=num_cores_per_replica,
|
||||
per_host_input_for_training=per_host_input_for_training,
|
||||
tpu_job_name=tpu_job_name,
|
||||
initial_infeed_sleep_secs=initial_infeed_sleep_secs,
|
||||
input_partition_dims=input_partition_dims,
|
||||
eval_training_input_configuration=eval_training_input_configuration)
|
||||
|
||||
|
||||
class RunConfig(run_config_lib.RunConfig):
|
||||
"""RunConfig with TPU support."""
|
||||
|
||||
def __init__(self,
|
||||
tpu_config=None,
|
||||
evaluation_master=None,
|
||||
master=None,
|
||||
cluster=None,
|
||||
**kwargs):
|
||||
"""Constructs a RunConfig.
|
||||
|
||||
Args:
|
||||
tpu_config: the TPUConfig that specifies TPU-specific configuration.
|
||||
evaluation_master: a string. The address of the master to use for eval.
|
||||
Defaults to master if not set.
|
||||
master: a string. The address of the master to use for training.
|
||||
cluster: a ClusterResolver
|
||||
**kwargs: keyword config parameters.
|
||||
|
||||
Raises:
|
||||
ValueError: if cluster is not None and the provided session_config has a
|
||||
cluster_def already.
|
||||
"""
|
||||
super(RunConfig, self).__init__(**kwargs)
|
||||
self._tpu_config = tpu_config or TPUConfig()
|
||||
self._cluster = cluster
|
||||
|
||||
# If user sets master and/or evaluation_master explicitly, including empty
|
||||
# string '', take it. Otherwise, take the values set by parent class.
|
||||
if master is not None:
|
||||
if cluster is not None:
|
||||
raise ValueError('Both master and cluster are set.')
|
||||
self._master = master
|
||||
else:
|
||||
if cluster:
|
||||
self._master = cluster.master()
|
||||
|
||||
if evaluation_master is not None:
|
||||
self._evaluation_master = evaluation_master
|
||||
elif (not self._evaluation_master and
|
||||
self.task_type != run_config_lib.TaskType.EVALUATOR):
|
||||
# If the task type is EVALUATOR, it means some cluster manager sets the
|
||||
# TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.
|
||||
#
|
||||
# Otherwise, it means user executes the code without external cluster
|
||||
# manager. For that, we optimize the user experience by setting
|
||||
# evaluation_master to master, unless user overwrites it.
|
||||
self._evaluation_master = self._master
|
||||
|
||||
# Set the ClusterSpec to use
|
||||
if cluster:
|
||||
self._cluster_spec = cluster.cluster_spec()
|
||||
|
||||
# Merge the cluster_def into the ConfigProto.
|
||||
if self._session_config is None: # pylint: disable=access-member-before-definition
|
||||
self._session_config = config_pb2.ConfigProto(
|
||||
allow_soft_placement=True, isolate_session_state=True)
|
||||
if self._session_config.HasField('cluster_def'):
|
||||
raise ValueError(
|
||||
'You cannot provide a ClusterResolver and '
|
||||
'session_config.cluster_def.')
|
||||
if self._cluster_spec:
|
||||
self._session_config.cluster_def.CopyFrom(
|
||||
self._cluster_spec.as_cluster_def())
|
||||
|
||||
def _maybe_overwrite_session_config_for_distributed_training(self):
|
||||
# Overrides the parent class session_config overwrite for between-graph. TPU
|
||||
# runs with in-graph, which should not have device filter. Doing nothing
|
||||
# ("pass") basically disables it.
|
||||
pass
|
||||
|
||||
@property
|
||||
def evaluation_master(self):
|
||||
return self._evaluation_master
|
||||
|
||||
@property
|
||||
def master(self):
|
||||
return self._master
|
||||
|
||||
@property
|
||||
def tpu_config(self):
|
||||
return self._tpu_config
|
||||
|
||||
@property
|
||||
def cluster(self):
|
||||
return self._cluster
|
||||
|
||||
def replace(self, **kwargs):
|
||||
if 'tpu_config' not in kwargs:
|
||||
return super(RunConfig, self).replace(**kwargs)
|
||||
|
||||
tpu_config = kwargs.pop('tpu_config')
|
||||
new_instance = super(RunConfig, self).replace(**kwargs)
|
||||
new_instance._tpu_config = tpu_config # pylint: disable=protected-access
|
||||
return new_instance
|
||||
|
||||
|
||||
def _get_tpu_job_name_from_tf_config():
|
||||
"""Extracts the TPU job name from TF_CONFIG env variable."""
|
||||
# TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster
|
||||
# spec propagation.
|
||||
tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))
|
||||
tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)
|
||||
if tpu_job_name:
|
||||
logging.info('Load TPU job name from TF_CONFIG: %s', tpu_job_name)
|
||||
return tpu_job_name
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow_estimator.python.estimator.tpu.tpu_config import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
@ -1,181 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TPU RunConfig tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.estimator import run_config as run_config_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu_config as tpu_config_lib
|
||||
|
||||
|
||||
def _set_tf_config_env_variable(tf_config):
|
||||
return test.mock.patch.dict('os.environ', {
|
||||
'TF_CONFIG': json.dumps(tf_config)
|
||||
})
|
||||
|
||||
|
||||
class TPURunConfigTest(test.TestCase):
|
||||
|
||||
def test_no_session_config_set_in_local_case(self):
|
||||
run_config = tpu_config_lib.RunConfig()
|
||||
self.assertIsNone(run_config.session_config)
|
||||
|
||||
def test_no_session_config_overwrite_in_local_case(self):
|
||||
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
|
||||
run_config = tpu_config_lib.RunConfig(session_config=session_config)
|
||||
self.assertEqual(session_config, run_config.session_config)
|
||||
|
||||
def test_no_session_config_set_with_cluster_spec(self):
|
||||
tf_config = {
|
||||
'cluster': {
|
||||
run_config_lib.TaskType.CHIEF: ['host3:3'],
|
||||
run_config_lib.TaskType.WORKER: ['host3:4']
|
||||
},
|
||||
'task': {
|
||||
'type': run_config_lib.TaskType.CHIEF,
|
||||
'index': 0
|
||||
}
|
||||
}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
run_config = tpu_config_lib.RunConfig()
|
||||
self.assertIsNone(run_config.session_config)
|
||||
|
||||
def test_no_session_config_overwrite_with_cluster_spec(self):
|
||||
tf_config = {
|
||||
'cluster': {
|
||||
run_config_lib.TaskType.CHIEF: ['host3:3'],
|
||||
run_config_lib.TaskType.WORKER: ['host3:4']
|
||||
},
|
||||
'task': {
|
||||
'type': run_config_lib.TaskType.CHIEF,
|
||||
'index': 0
|
||||
}
|
||||
}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
session_config = config_pb2.ConfigProto(allow_soft_placement=True)
|
||||
run_config = tpu_config_lib.RunConfig(session_config=session_config)
|
||||
self.assertEqual(session_config, run_config.session_config)
|
||||
|
||||
def test_fail_with_invalid_num_shards(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'must be positive'):
|
||||
tpu_config_lib.RunConfig(
|
||||
tpu_config=tpu_config_lib.TPUConfig(num_shards=0))
|
||||
|
||||
def test_fail_with_iterations_per_loop(self):
|
||||
with self.assertRaisesRegexp(ValueError, 'must be positive'):
|
||||
tpu_config_lib.RunConfig(
|
||||
tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
|
||||
|
||||
def test_fail_with_invalid_num_cores_per_replica(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;'
|
||||
' got 7'):
|
||||
tpu_config_lib.TPUConfig(num_cores_per_replica=7)
|
||||
|
||||
|
||||
class TPURunConfigMasterTest(test.TestCase):
|
||||
|
||||
def test_default_values(self):
|
||||
run_config = tpu_config_lib.RunConfig()
|
||||
self.assertEqual('', run_config.master)
|
||||
self.assertEqual('', run_config.evaluation_master)
|
||||
|
||||
def test_user_provided_master_and_evaluation_master(self):
|
||||
run_config = tpu_config_lib.RunConfig(
|
||||
master='_master_123', evaluation_master='_eval_master_123')
|
||||
self.assertEqual('_master_123', run_config.master)
|
||||
self.assertEqual('_eval_master_123', run_config.evaluation_master)
|
||||
|
||||
def test_evaluation_master_defaults_to_master(self):
|
||||
run_config = tpu_config_lib.RunConfig(master='_master_123')
|
||||
self.assertEqual('_master_123', run_config.master)
|
||||
self.assertEqual('_master_123', run_config.evaluation_master)
|
||||
|
||||
def test_tf_config(self):
|
||||
tf_config = {
|
||||
'session_master': '_master_123',
|
||||
'eval_session_master': '_eval_master_123'
|
||||
}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
run_config = tpu_config_lib.RunConfig()
|
||||
self.assertEqual('_master_123', run_config.master)
|
||||
self.assertEqual('_eval_master_123', run_config.evaluation_master)
|
||||
|
||||
def test_evaluation_master_defaults_to_master_in_tf_config(self):
|
||||
tf_config = {
|
||||
'session_master': '_master_123',
|
||||
}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
run_config = tpu_config_lib.RunConfig()
|
||||
self.assertEqual('_master_123', run_config.master)
|
||||
self.assertEqual('_master_123', run_config.evaluation_master)
|
||||
|
||||
def test_respect_evaluation_master_in_tf_config(self):
|
||||
tf_config = {
|
||||
'cluster': {
|
||||
run_config_lib.TaskType.CHIEF: ['host0:0'],
|
||||
},
|
||||
'task': {
|
||||
'type': run_config_lib.TaskType.EVALUATOR,
|
||||
'index': 0
|
||||
},
|
||||
}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
run_config = tpu_config_lib.RunConfig(master='_something')
|
||||
self.assertEqual('', run_config.evaluation_master)
|
||||
|
||||
def test_user_overwrites_tf_config(self):
|
||||
tf_config = {
|
||||
'session_master': '_master_123',
|
||||
'eval_session_master': '_eval_master_123'
|
||||
}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
run_config = tpu_config_lib.RunConfig(
|
||||
master='_new_master_123', evaluation_master='_new_eval_master_123')
|
||||
self.assertEqual('_new_master_123', run_config.master)
|
||||
self.assertEqual('_new_eval_master_123', run_config.evaluation_master)
|
||||
|
||||
def test_user_overwrites_master_in_tf_config(self):
|
||||
tf_config = {
|
||||
'session_master': '_master_123',
|
||||
'eval_session_master': '_eval_master_123'
|
||||
}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
run_config = tpu_config_lib.RunConfig(master='_new_master_123')
|
||||
self.assertEqual('_new_master_123', run_config.master)
|
||||
self.assertEqual('_eval_master_123', run_config.evaluation_master)
|
||||
|
||||
|
||||
class TPUJobNameTest(test.TestCase):
|
||||
|
||||
def test_default_name(self):
|
||||
config = tpu_config_lib.RunConfig()
|
||||
self.assertIsNone(config.tpu_config.tpu_job_name)
|
||||
|
||||
def test_with_tf_config(self):
|
||||
tf_config = {'service': {'tpu_worker_job_name': '_my_new_name',}}
|
||||
with _set_tf_config_env_variable(tf_config):
|
||||
config = tpu_config_lib.RunConfig()
|
||||
self.assertEqual('_my_new_name', config.tpu_config.tpu_job_name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,749 +1,23 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
"""TPU system metadata and associated tooling."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from contextlib import contextmanager
|
||||
import copy
|
||||
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import _tpu_estimator_embedding
|
||||
from tensorflow.python.tpu import device_assignment as tpu_device_assignment
|
||||
from tensorflow.python.tpu import tpu_config
|
||||
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
|
||||
|
||||
|
||||
_DEFAULT_JOB_NAME = 'tpu_worker'
|
||||
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
|
||||
_LOCAL_MASTERS = ('', 'local')
|
||||
_NUM_CORES_TO_COMPUTATION_SHAPE = {
|
||||
1: [1, 1, 1],
|
||||
2: [1, 1, 2],
|
||||
4: [1, 2, 2],
|
||||
8: [2, 2, 2],
|
||||
16: [4, 2, 2],
|
||||
}
|
||||
|
||||
|
||||
class TPUContext(object):
|
||||
"""A context that holds the current configuration of the TPU computation."""
|
||||
|
||||
def __init__(self,
|
||||
internal_ctx,
|
||||
input_device=None,
|
||||
invocation_index=None,
|
||||
call_from_input_fn=True):
|
||||
self._internal_ctx = internal_ctx
|
||||
self._input_device = input_device
|
||||
self._invocation_index = invocation_index
|
||||
self._call_from_input_fn = call_from_input_fn
|
||||
|
||||
def current_input_fn_deployment(self):
|
||||
"""The configuration of the current input_fn invocation.
|
||||
|
||||
The configuration depends on `TPUConfig.per_host_input_for_training`. See
|
||||
`TPUConfig` for details.
|
||||
|
||||
Only set in params dict of input_fn
|
||||
|
||||
Returns:
|
||||
A tuple of
|
||||
1. Device spec string: String, is the current CPU host where the
|
||||
input_fn is invoked.
|
||||
2. Current invocation index: Int, 0-based index of the input_fn
|
||||
invocation. See next item for details.
|
||||
3. Total invocation count: Int, the total number of times to invoke the
|
||||
input_fn on all CPU hosts. Each invocation will be passed with a new
|
||||
`TPUContext` instance with current invocation index set properly.
|
||||
4. Total number of replicas consumed by current_invocation: Int, the
|
||||
number of replicas fed by the data returned by current input_fn. For
|
||||
example, for per_core input pipeline deployment
|
||||
and non-model-parallelism, total invocation count is equal to
|
||||
the number of cores in the system and num replicas consumed by
|
||||
current invocation is 1. For per-host v2 input pipeline deployment,
|
||||
total invocation count is equal to the number of hosts in the system
|
||||
and num replicas consumed by current invocation is equal to number of
|
||||
cores per host.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If this method must not be called from input_fn.
|
||||
"""
|
||||
if not self._call_from_input_fn:
|
||||
raise RuntimeError('This TPUContext instance must not be called from'
|
||||
' model_fn.')
|
||||
|
||||
if self._internal_ctx.is_input_sharded_per_core():
|
||||
total_invocation_count = (self._internal_ctx.num_hosts
|
||||
* self._internal_ctx.num_of_replicas_per_host)
|
||||
replicas_consumed = 1
|
||||
elif self._internal_ctx.is_input_broadcast_with_iterators():
|
||||
total_invocation_count = 1
|
||||
replicas_consumed = self._internal_ctx.num_replicas
|
||||
else:
|
||||
total_invocation_count = self._internal_ctx.num_hosts
|
||||
replicas_consumed = self._internal_ctx.num_of_replicas_per_host
|
||||
return (self._input_device, self._invocation_index,
|
||||
total_invocation_count, replicas_consumed)
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
"""The total number of replicas.
|
||||
|
||||
For non-model-parallelism, num_replicas should be the total num of TPU
|
||||
cores in the system.
|
||||
|
||||
Returns:
|
||||
The number of replicas.
|
||||
"""
|
||||
return self._internal_ctx.num_replicas
|
||||
|
||||
@property
|
||||
def num_hosts(self):
|
||||
"""The number of hosts for the TPU system."""
|
||||
return self._internal_ctx.num_hosts
|
||||
|
||||
@property
|
||||
def current_host(self):
|
||||
"""The current host index for the TPU system."""
|
||||
return self._invocation_index
|
||||
|
||||
@property
|
||||
def num_of_replicas_per_host(self):
|
||||
"""The number of replicas for each host."""
|
||||
if self._internal_ctx.model_parallelism_enabled:
|
||||
raise ValueError(
|
||||
'num_of_replicas_per_host is not supported for model_parallelism')
|
||||
return self._internal_ctx.num_of_replicas_per_host
|
||||
|
||||
@property
|
||||
def device_assignment(self):
|
||||
"""Returns device_assignment object."""
|
||||
if self._call_from_input_fn:
|
||||
raise RuntimeError('This TPUContext instance must not be called from'
|
||||
' input_fn.')
|
||||
return self._internal_ctx.device_assignment
|
||||
|
||||
def device_for_replica(self, replica_id):
|
||||
"""Returns the tuple of (CPU device and device ordinal) for replica.
|
||||
|
||||
This should be used for full replicate for non-model-parallelism.
|
||||
|
||||
Args:
|
||||
replica_id: Int, the replica index.
|
||||
|
||||
Returns:
|
||||
A tuple of device spec for CPU device and int device ordinal.
|
||||
"""
|
||||
# Note that: For the non-model parallelism, the mapping could be
|
||||
# a random permutation. The order should not matter in most cases
|
||||
# as far as model is replicated to all cores in the system.
|
||||
return self._internal_ctx.device_for_replica(replica_id)
|
||||
|
||||
@property
|
||||
def tpu_host_placement_function(self):
|
||||
"""Returns the TPU host place function.
|
||||
|
||||
The place function takes host_id as the input and returns the TF device
|
||||
for the correspoding host.
|
||||
"""
|
||||
|
||||
def _placement_function(host_id):
|
||||
"""Return the host device given host_id."""
|
||||
return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
|
||||
|
||||
return _placement_function
|
||||
|
||||
|
||||
class _InternalTPUContext(object):
|
||||
"""A context holds immutable states of TPU computation.
|
||||
|
||||
This immutable object holds TPUEstimator config, train/eval batch size, and
|
||||
`TPUEstimator.use_tpu`, which is expected to be passed around. It also
|
||||
provides utility functions, based on the current state, to determine other
|
||||
information commonly required by TPU computation, such as TPU device names,
|
||||
TPU hosts, shard batch size, etc.
|
||||
|
||||
if eval_on_tpu is False, then execution of eval on TPU is disabled.
|
||||
if eval_on_tpu is True, but use_tpu is False, a warning is issued,
|
||||
and TPU execution is disabled for all modes.
|
||||
|
||||
N.B. As `mode` is not immutable state in Estimator, but essential to
|
||||
distinguish between TPU training and evaluation, a common usage for
|
||||
_InternalTPUContext with `mode` is as follows:
|
||||
```
|
||||
with _ctx.with_mode(mode) as ctx:
|
||||
if ctx.is_running_on_cpu():
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
train_batch_size,
|
||||
eval_batch_size,
|
||||
predict_batch_size,
|
||||
use_tpu,
|
||||
eval_on_tpu=True,
|
||||
embedding_config_spec=None):
|
||||
self._config = config
|
||||
self._train_batch_size = train_batch_size
|
||||
self._eval_batch_size = eval_batch_size
|
||||
self._predict_batch_size = predict_batch_size
|
||||
self._use_tpu = use_tpu
|
||||
logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)
|
||||
if not use_tpu and eval_on_tpu:
|
||||
logging.warning('eval_on_tpu ignored because use_tpu is False.')
|
||||
|
||||
self._eval_on_tpu = eval_on_tpu
|
||||
self._model_parallelism_enabled = (
|
||||
use_tpu and config.tpu_config.num_cores_per_replica)
|
||||
self._mode = None
|
||||
num_cores_per_replica = config.tpu_config.num_cores_per_replica
|
||||
if self._model_parallelism_enabled:
|
||||
self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
|
||||
num_cores_per_replica]
|
||||
else:
|
||||
self._computation_shape = None
|
||||
self._lazy_tpu_system_metadata_dict = {} # key by master address
|
||||
self._lazy_device_assignment_dict = {} # key by master address
|
||||
self._lazy_validation_dict = {} # key by ModeKeys
|
||||
self._embedding_config_spec = embedding_config_spec
|
||||
self._lazy_embedding_config_dict = {} # key by master address
|
||||
|
||||
def _assert_mode(self):
|
||||
if self._mode is None:
|
||||
raise RuntimeError(
|
||||
'`mode` needs to be set via contextmanager `with_mode`.')
|
||||
return self._mode
|
||||
|
||||
@contextmanager
|
||||
def with_mode(self, mode):
|
||||
# NOTE(xiejw): Shallow copy is enough. It will share the lazy dictionaries,
|
||||
# such as _lazy_tpu_system_metadata_dict between new copy and the original
|
||||
# one. Note that all lazy states stored in properties _lazy_foo are sort of
|
||||
# immutable, as they should remain the same for the process lifetime.
|
||||
new_ctx = copy.copy(self)
|
||||
new_ctx._mode = mode # pylint: disable=protected-access
|
||||
yield new_ctx
|
||||
|
||||
@property
|
||||
def mode(self):
|
||||
return self._assert_mode()
|
||||
|
||||
def _get_master_address(self):
|
||||
mode = self._assert_mode()
|
||||
config = self._config
|
||||
master = (
|
||||
config.master
|
||||
if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)
|
||||
return master
|
||||
|
||||
def _get_tpu_system_metadata(self):
|
||||
"""Gets the (maybe cached) TPU system metadata."""
|
||||
master = self._get_master_address()
|
||||
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
|
||||
if tpu_system_metadata is not None:
|
||||
return tpu_system_metadata
|
||||
|
||||
cluster_def = None
|
||||
if (self._config.session_config and
|
||||
self._config.session_config.cluster_def.job):
|
||||
cluster_def = self._config.session_config.cluster_def
|
||||
|
||||
# pylint: disable=protected-access
|
||||
tpu_system_metadata = (
|
||||
tpu_system_metadata_lib._query_tpu_system_metadata(
|
||||
master,
|
||||
cluster_def=cluster_def,
|
||||
query_topology=self.model_parallelism_enabled))
|
||||
|
||||
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
|
||||
return tpu_system_metadata
|
||||
|
||||
def _get_device_assignment(self):
|
||||
"""Gets the (maybe cached) TPU device assignment."""
|
||||
master = self._get_master_address()
|
||||
device_assignment = self._lazy_device_assignment_dict.get(master)
|
||||
if device_assignment is not None:
|
||||
return device_assignment
|
||||
|
||||
tpu_system_metadata = self._get_tpu_system_metadata()
|
||||
|
||||
device_assignment = tpu_device_assignment.device_assignment(
|
||||
tpu_system_metadata.topology,
|
||||
computation_shape=self._computation_shape,
|
||||
num_replicas=self.num_replicas)
|
||||
|
||||
logging.info('num_cores_per_replica: %s',
|
||||
str(self._config.tpu_config.num_cores_per_replica))
|
||||
logging.info('computation_shape: %s', str(self._computation_shape))
|
||||
logging.info('num_replicas: %d', self.num_replicas)
|
||||
logging.info('device_assignment.topology.device_coordinates: %s',
|
||||
str(device_assignment.topology.device_coordinates))
|
||||
logging.info('device_assignment.core_assignment: %s',
|
||||
str(device_assignment.core_assignment))
|
||||
|
||||
self._lazy_device_assignment_dict[master] = device_assignment
|
||||
return device_assignment
|
||||
|
||||
@property
|
||||
def embedding_config(self):
|
||||
"""Returns the embedding config based on current mode."""
|
||||
master = self._get_master_address()
|
||||
if master in self._lazy_embedding_config_dict:
|
||||
embedding_config = self._lazy_embedding_config_dict[master]
|
||||
else:
|
||||
embedding_config = None
|
||||
if self._use_tpu and self._embedding_config_spec:
|
||||
embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
|
||||
self._embedding_config_spec, self._train_batch_size,
|
||||
self._eval_batch_size, self.num_hosts, self.num_cores, self.config)
|
||||
if not embedding_config.has_embedding_tables():
|
||||
embedding_config = None
|
||||
self._lazy_embedding_config_dict[master] = embedding_config
|
||||
|
||||
if embedding_config is not None:
|
||||
mode = self._assert_mode()
|
||||
# Dynamically attach tpu_embedding based on mode. With
|
||||
# this, we could keep embedding_config immutable but call site always
|
||||
# accesses the unified API '.tpu_embedding'.
|
||||
embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)
|
||||
return embedding_config
|
||||
|
||||
@property
|
||||
def model_parallelism_enabled(self):
|
||||
return self._model_parallelism_enabled
|
||||
|
||||
@property
|
||||
def input_partition_dims(self):
|
||||
return self._config.tpu_config.input_partition_dims
|
||||
|
||||
@property
|
||||
def device_assignment(self):
|
||||
return (self._get_device_assignment()
|
||||
if self._model_parallelism_enabled else None)
|
||||
|
||||
@property
|
||||
def num_of_cores_per_host(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_of_cores_per_host
|
||||
|
||||
@property
|
||||
def num_cores(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_cores
|
||||
|
||||
@property
|
||||
def num_of_replicas_per_host(self):
|
||||
"""Return the number of replicas per host."""
|
||||
if self.model_parallelism_enabled:
|
||||
return self.num_replicas // self.num_hosts
|
||||
else:
|
||||
return self.num_of_cores_per_host
|
||||
|
||||
@property
|
||||
def num_replicas(self):
|
||||
num_cores_in_system = self.num_cores
|
||||
|
||||
if self.model_parallelism_enabled:
|
||||
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
|
||||
if num_cores_per_replica > num_cores_in_system:
|
||||
raise ValueError(
|
||||
'The num of cores required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica, is larger than the total num of '
|
||||
'TPU cores in the system. num_cores_per_replica: {}, num cores '
|
||||
'in the system: {}'.format(num_cores_per_replica,
|
||||
num_cores_in_system))
|
||||
|
||||
if num_cores_in_system % num_cores_per_replica != 0:
|
||||
raise RuntimeError(
|
||||
'The num of cores in the system ({}) is not divisible by the num '
|
||||
'of cores ({}) required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica. This should never happen!'.format(
|
||||
num_cores_in_system, num_cores_per_replica))
|
||||
|
||||
return num_cores_in_system // num_cores_per_replica
|
||||
else:
|
||||
return num_cores_in_system
|
||||
|
||||
@property
|
||||
def num_hosts(self):
|
||||
metadata = self._get_tpu_system_metadata()
|
||||
return metadata.num_hosts
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self._config
|
||||
|
||||
def is_input_sharded_per_core(self):
|
||||
"""Return true if input_fn is invoked per-core (other than per-host)."""
|
||||
mode = self._assert_mode()
|
||||
return (mode == model_fn_lib.ModeKeys.TRAIN and
|
||||
(self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.PER_SHARD_V1))
|
||||
|
||||
def is_input_per_host_with_iterators(self):
|
||||
"""Return true if input_fn should be run in the per-host v2 config."""
|
||||
return (self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.PER_HOST_V2)
|
||||
|
||||
def is_input_broadcast_with_iterators(self):
|
||||
"""Return true if input_fn should be run in the full_replicae config."""
|
||||
mode = self._assert_mode()
|
||||
return ((self._config.tpu_config.per_host_input_for_training is
|
||||
tpu_config.InputPipelineConfig.BROADCAST) or
|
||||
(mode != model_fn_lib.ModeKeys.TRAIN and
|
||||
self._config.tpu_config.eval_training_input_configuration is
|
||||
tpu_config.InputPipelineConfig.SLICED))
|
||||
|
||||
def is_running_on_cpu(self, is_export_mode=False):
|
||||
"""Determines whether the input_fn and model_fn should be invoked on CPU.
|
||||
|
||||
This API also validates the user-provided configuration, such as batch size,
|
||||
according to the lazily initialized TPU system metadata.
|
||||
|
||||
Args:
|
||||
is_export_mode: Indicates whether the current mode is for exporting the
|
||||
model, when mode == PREDICT. Only with this bool can we
|
||||
tell whether the user is calling Estimator.predict or
|
||||
Estimator.export_savedmodel, which run on TPU and CPU
|
||||
respectively. Parent class Estimator does not distinguish these two.
|
||||
|
||||
Returns:
|
||||
bool, whether current input_fn or model_fn should be running on CPU.
|
||||
|
||||
Raises:
|
||||
ValueError: if any configuration is invalid.
|
||||
"""
|
||||
|
||||
is_running_on_cpu = self._is_running_on_cpu(is_export_mode)
|
||||
if not is_running_on_cpu:
|
||||
self._validate_tpu_configuration()
|
||||
return is_running_on_cpu
|
||||
|
||||
def _is_running_on_cpu(self, is_export_mode):
|
||||
"""Determines whether the input_fn and model_fn should be invoked on CPU."""
|
||||
mode = self._assert_mode()
|
||||
|
||||
if not self._use_tpu:
|
||||
return True
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:
|
||||
logging.info('_is_running_on_cpu: eval_on_tpu disabled')
|
||||
return True
|
||||
|
||||
if is_export_mode:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def global_batch_size(self):
|
||||
mode = self._assert_mode()
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
return self._train_batch_size
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
return self._eval_batch_size
|
||||
elif mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
return self._predict_batch_size
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def batch_size_for_input_fn(self):
|
||||
"""Returns the shard batch size for `input_fn`."""
|
||||
global_batch_size = self.global_batch_size
|
||||
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
|
||||
return global_batch_size
|
||||
|
||||
# On TPU
|
||||
if self.is_input_sharded_per_core() or (
|
||||
self.is_input_per_host_with_iterators()):
|
||||
return global_batch_size // self.num_replicas
|
||||
else:
|
||||
return global_batch_size // self.num_hosts
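# Worked example (hypothetical numbers): with a global train_batch_size of
# 1024 on a 32-replica system spread over 4 hosts (no model parallelism),
# PER_SHARD_V1 and PER_HOST_V2 input yield 1024 // 32 = 32 per input_fn
# invocation, while the per-host v1 setting yields 1024 // 4 = 256 per host.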
|
||||
|
||||
@property
|
||||
def batch_size_for_model_fn(self):
|
||||
"""Returns the shard batch size for `model_fn`."""
|
||||
global_batch_size = self.global_batch_size
|
||||
|
||||
if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):
|
||||
return global_batch_size
|
||||
|
||||
# On TPU, the batch is always sharded per replica.
|
||||
return global_batch_size // self.num_replicas
|
||||
|
||||
@property
|
||||
def master_job(self):
|
||||
"""Returns the job name to use to place TPU computations on.
|
||||
|
||||
Returns:
|
||||
A string containing the job name, or None if no job should be specified.
|
||||
|
||||
Raises:
|
||||
ValueError: If the user needs to specify a tpu_job_name, because we are
|
||||
unable to infer the job name automatically, or if the user-specified job
|
||||
names are inappropriate.
|
||||
"""
|
||||
run_config = self._config
|
||||
# If the user specifies the tpu_job_name, use that.
|
||||
if run_config.tpu_config.tpu_job_name:
|
||||
return run_config.tpu_config.tpu_job_name
|
||||
|
||||
# The tpu job is determined by the run_config. Right now, this method is
|
||||
# required as tpu_config is not part of the RunConfig.
|
||||
mode = self._assert_mode()
|
||||
master = (
|
||||
run_config.evaluation_master
|
||||
if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
|
||||
cluster_def = (run_config.session_config.cluster_def
|
||||
if run_config.session_config else None)
|
||||
|
||||
return tpu_system_metadata_lib.master_job(master, cluster_def)
|
||||
|
||||
@property
|
||||
def tpu_host_placement_function(self):
|
||||
"""Returns the TPU host place function."""
|
||||
|
||||
master = self.master_job
|
||||
|
||||
def _placement_function(_sentinal=None, replica_id=None, host_id=None): # pylint: disable=invalid-name
|
||||
"""Return the host device given replica_id or host_id."""
|
||||
assert _sentinal is None
|
||||
if replica_id is not None and host_id is not None:
|
||||
raise RuntimeError(
|
||||
'replica_id and host_id can have only one non-None value.')
|
||||
|
||||
if master is None:
|
||||
return '/replica:0/task:0/device:CPU:0'
|
||||
else:
|
||||
if replica_id is not None:
|
||||
if self.model_parallelism_enabled:
|
||||
return self.device_assignment.host_device(
|
||||
replica=replica_id, job=master)
|
||||
else:
|
||||
host_id = replica_id // self.num_of_cores_per_host
|
||||
|
||||
return '/job:%s/task:%d/device:CPU:0' % (master, host_id)
|
||||
|
||||
return _placement_function
|
||||
|
||||
@property
|
||||
def tpu_device_placement_function(self):
|
||||
"""Returns a TPU device placement Fn."""
|
||||
master = self.master_job
|
||||
job_device = '' if master is None else ('/job:%s' % master)
|
||||
|
||||
def _placement_function(i):
|
||||
if self.model_parallelism_enabled:
|
||||
return self.device_assignment.tpu_device(replica=i, job=master)
|
||||
else:
|
||||
num_of_cores_per_host = self.num_of_cores_per_host
|
||||
host_id = i // num_of_cores_per_host
|
||||
ordinal_id = i % num_of_cores_per_host
|
||||
return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)
|
||||
|
||||
return _placement_function
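# Example (hypothetical job name): with master_job == 'tpu_worker' and
# 8 cores per host, _placement_function(10) returns
# '/job:tpu_worker/task:1/device:TPU:2' in the non-model-parallel case.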
|
||||
|
||||
def tpu_ordinal_function(self, host_id):
|
||||
"""Returns the TPU ordinal fn."""
|
||||
|
||||
def _tpu_ordinal_function(shard_index_in_host):
|
||||
"""Return the TPU ordinal associated with a shard.
|
||||
|
||||
Required because the enqueue ops are placed on CPU.
|
||||
|
||||
Args:
|
||||
shard_index_in_host: the shard index
|
||||
|
||||
Returns:
|
||||
The ordinal of the TPU device the shard's infeed should be placed on.
|
||||
"""
|
||||
if self.model_parallelism_enabled:
|
||||
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
|
||||
replica = self.device_assignment.lookup_replicas(host_id,
|
||||
0)[shard_index_in_host]
|
||||
return self.device_assignment.tpu_ordinal(replica=replica)
|
||||
else:
|
||||
return shard_index_in_host % self.num_of_cores_per_host
|
||||
|
||||
return _tpu_ordinal_function
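# Example (no model parallelism, 8 cores per host): the returned function maps
# shard_index_in_host 3 to TPU ordinal 3, since 3 % 8 == 3; with model
# parallelism the ordinal is instead looked up from the device assignment.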
|
||||
|
||||
def _validate_tpu_configuration(self):
|
||||
"""Validates the configuration based on the TPU system metadata."""
|
||||
mode = self._assert_mode()
|
||||
if self._lazy_validation_dict.get(mode):
|
||||
return
|
||||
|
||||
# All following information is obtained from TPU system metadata.
|
||||
num_cores = self.num_cores
|
||||
num_replicas = self.num_replicas
|
||||
num_hosts = self.num_hosts
|
||||
|
||||
if not num_cores:
|
||||
tpu_system_metadata = self._get_tpu_system_metadata()
|
||||
raise RuntimeError(
|
||||
'Cannot find any TPU cores in the system. Please double check '
|
||||
'Tensorflow master address and TPU worker(s). Available devices '
|
||||
'are {}.'.format(tpu_system_metadata.devices))
|
||||
|
||||
if self._config.tpu_config.num_shards:
|
||||
user_provided_num_replicas = self._config.tpu_config.num_shards
|
||||
if user_provided_num_replicas != num_replicas:
|
||||
message = (
|
||||
'TPUConfig.num_shards is not set correctly. According to TPU '
|
||||
'system metadata for Tensorflow master ({}): num_replicas should '
|
||||
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
|
||||
'be the total num of TPU cores in the system. For '
|
||||
'model-parallelism, the total number of TPU cores should be '
|
||||
'num_cores_per_replica * num_replicas. Please set it '
|
||||
'accordingly or leave it as `None`'.format(
|
||||
self._get_master_address(), num_replicas,
|
||||
user_provided_num_replicas))
|
||||
|
||||
raise ValueError(message)
|
||||
|
||||
if self._config.tpu_config.num_cores_per_replica:
|
||||
num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
|
||||
num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
|
||||
if num_cores_per_replica > num_cores_per_host:
|
||||
raise ValueError(
|
||||
'The num of cores required by the model parallelism, specified by '
|
||||
'TPUConfig.num_cores_per_replica, is larger than the '
|
||||
'num_cores_per_host. num_cores_per_replica: {}, '
|
||||
'num_cores_per_host: {}'.format(num_cores_per_replica,
|
||||
num_cores_per_host))
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
if (self._train_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'train batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._train_batch_size, num_replicas))
|
||||
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
if self._eval_batch_size is None:
|
||||
raise ValueError(
|
||||
'eval_batch_size in TPUEstimator constructor cannot be `None` '
|
||||
'if .evaluate is running on TPU.')
|
||||
if (self._eval_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'eval batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._eval_batch_size, num_replicas))
|
||||
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
|
||||
raise ValueError(
|
||||
'TPUEstimator.evaluate should be running on single TPU'
|
||||
' instead of a Pod.')
|
||||
else:
|
||||
assert mode == model_fn_lib.ModeKeys.PREDICT
|
||||
if self._predict_batch_size is None:
|
||||
raise ValueError(
|
||||
'predict_batch_size in TPUEstimator constructor should not be '
|
||||
'`None` if .predict is running on TPU.')
|
||||
if (self._predict_batch_size % num_replicas != 0 and
|
||||
not self.is_input_broadcast_with_iterators()):
|
||||
raise ValueError(
|
||||
'predict batch size {} must be divisible by number of replicas {}'
|
||||
.format(self._predict_batch_size, num_replicas))
|
||||
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
|
||||
raise ValueError(
|
||||
'TPUEstimator.predict should be running on single TPU worker. '
|
||||
'got {}.'.format(num_hosts))
|
||||
|
||||
# Record the state "validated" into lazy dictionary.
|
||||
self._lazy_validation_dict[mode] = True
|
||||
|
||||
def device_for_replica(self, replica_id):
|
||||
"""Returns the tuple of (CPU device and device ordinal) for replica.
|
||||
|
||||
This should be used for full replication (non-model-parallelism).
|
||||
|
||||
Args:
|
||||
replica_id: Int, the replica index.
|
||||
|
||||
Returns:
|
||||
A tuple of device spec for CPU device and int device ordinal.
|
||||
"""
|
||||
master = self.master_job
|
||||
|
||||
if self.model_parallelism_enabled:
|
||||
return (self.device_assignment.host_device(
|
||||
replica=replica_id, job=master),
|
||||
self.device_assignment.tpu_ordinal(replica=replica_id))
|
||||
|
||||
job_device = '' if master is None else ('/job:%s' % master)
|
||||
|
||||
num_of_replicas_per_host = self.num_of_replicas_per_host
|
||||
host_id = replica_id // num_of_replicas_per_host
|
||||
ordinal_id = replica_id % num_of_replicas_per_host
|
||||
|
||||
host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)
|
||||
return (host_device, ordinal_id)
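# Example (hypothetical job name, no model parallelism, 8 replicas per host):
# device_for_replica(11) returns ('/job:tpu_worker/task:1/device:CPU:0', 3),
# since host_id == 11 // 8 == 1 and ordinal_id == 11 % 8 == 3.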
|
||||
|
||||
|
||||
class _OneCoreTPUContext(_InternalTPUContext):
|
||||
"""Special _InternalTPUContext for one core usage."""
|
||||
|
||||
def __init__(self, config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu):
|
||||
|
||||
super(_OneCoreTPUContext, self).__init__(
|
||||
config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu)
|
||||
|
||||
def _get_tpu_system_metadata(self):
|
||||
"""Gets the (maybe cached) TPU system metadata."""
|
||||
master = self._get_master_address()
|
||||
tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)
|
||||
if tpu_system_metadata is not None:
|
||||
return tpu_system_metadata
|
||||
|
||||
tpu_system_metadata = (
|
||||
tpu_system_metadata_lib._TPUSystemMetadata( # pylint: disable=protected-access
|
||||
num_cores=1,
|
||||
num_hosts=1,
|
||||
num_of_cores_per_host=1,
|
||||
topology=None,
|
||||
devices=[]))
|
||||
|
||||
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
|
||||
return tpu_system_metadata
|
||||
|
||||
|
||||
def _get_tpu_context(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu, eval_on_tpu,
|
||||
embedding_config_spec):
|
||||
"""Returns an instance of `_InternalTPUContext`."""
|
||||
|
||||
if (config.tpu_config.num_shards == 1 and
|
||||
config.tpu_config.num_cores_per_replica is None):
|
||||
if embedding_config_spec is not None:
|
||||
raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '
|
||||
'when embedding_config_spec is not None.')
|
||||
logging.warning(
|
||||
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
|
||||
'Please fix as soon as possible (leaving num_shards as None.)')
|
||||
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu)
|
||||
|
||||
return _InternalTPUContext(config, train_batch_size, eval_batch_size,
|
||||
predict_batch_size, use_tpu, eval_on_tpu,
|
||||
embedding_config_spec)
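# Sketch of how TPUEstimator might construct and use the context (argument
# values are hypothetical):
#
#   ctx = _get_tpu_context(run_config, train_batch_size=1024,
#                          eval_batch_size=None, predict_batch_size=None,
#                          use_tpu=True, eval_on_tpu=True,
#                          embedding_config_spec=None)
#   with ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as train_ctx:
#     per_replica_batch = train_ctx.batch_size_for_model_fn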
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow_estimator.python.estimator.tpu.tpu_context import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|
File diff suppressed because it is too large
@ -1,339 +0,0 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TPU Estimator Signalling Tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.tpu import tpu_estimator
|
||||
|
||||
|
||||
def make_input_fn(num_samples):
|
||||
a = np.linspace(0, 100.0, num=num_samples)
|
||||
b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
|
||||
|
||||
def input_fn(params):
|
||||
batch_size = params['batch_size']
|
||||
da1 = dataset_ops.Dataset.from_tensor_slices(a)
|
||||
da2 = dataset_ops.Dataset.from_tensor_slices(b)
|
||||
|
||||
dataset = dataset_ops.Dataset.zip((da1, da2))
|
||||
dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb})
|
||||
dataset = dataset.batch(batch_size)
|
||||
return dataset
|
||||
return input_fn, (a, b)
|
||||
|
||||
|
||||
def make_input_fn_with_labels(num_samples):
|
||||
a = np.linspace(0, 100.0, num=num_samples)
|
||||
b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))
|
||||
|
||||
def input_fn(params):
|
||||
batch_size = params['batch_size']
|
||||
da1 = dataset_ops.Dataset.from_tensor_slices(a)
|
||||
da2 = dataset_ops.Dataset.from_tensor_slices(b)
|
||||
|
||||
dataset = dataset_ops.Dataset.zip((da1, da2))
|
||||
dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb))
|
||||
dataset = dataset.batch(batch_size)
|
||||
return dataset
|
||||
return input_fn, (a, b)
|
||||
|
||||
|
||||
class TPUEstimatorStoppingSignalsTest(test.TestCase):
|
||||
|
||||
def test_normal_output_without_signals(self):
|
||||
num_samples = 4
|
||||
batch_size = 2
|
||||
|
||||
params = {'batch_size': batch_size}
|
||||
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
dataset = input_fn(params)
|
||||
features = dataset_ops.make_one_shot_iterator(dataset).get_next()
|
||||
|
||||
# With tf.data.Dataset.batch, the batch dimension is None, i.e., dynamic shape.
|
||||
self.assertIsNone(features['a'].shape.as_list()[0])
|
||||
|
||||
with session.Session() as sess:
|
||||
result = sess.run(features)
|
||||
self.assertAllEqual(a[:batch_size], result['a'])
|
||||
self.assertAllEqual(b[:batch_size], result['b'])
|
||||
|
||||
# This run should work as num_samples / batch_size = 2.
|
||||
result = sess.run(features)
|
||||
self.assertAllEqual(a[batch_size:num_samples], result['a'])
|
||||
self.assertAllEqual(b[batch_size:num_samples], result['b'])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
# Given num_samples and batch_size, this run should fail.
|
||||
sess.run(features)
|
||||
|
||||
def test_output_with_stopping_signals(self):
|
||||
num_samples = 4
|
||||
batch_size = 2
|
||||
|
||||
params = {'batch_size': batch_size}
|
||||
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
dataset = input_fn(params)
|
||||
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size)
|
||||
dataset_initializer = inputs.dataset_initializer()
|
||||
features, _ = inputs.features_and_labels()
|
||||
signals = inputs.signals()
|
||||
|
||||
# With tf.data.Dataset.batch, the batch dimension is None, i.e., dynamic shape.
|
||||
self.assertIsNone(features['a'].shape.as_list()[0])
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(dataset_initializer)
|
||||
|
||||
result, evaluated_signals = sess.run([features, signals])
|
||||
self.assertAllEqual(a[:batch_size], result['a'])
|
||||
self.assertAllEqual(b[:batch_size], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
# This run should work as num_samples / batch_size = 2.
|
||||
result, evaluated_signals = sess.run([features, signals])
|
||||
self.assertAllEqual(a[batch_size:num_samples], result['a'])
|
||||
self.assertAllEqual(b[batch_size:num_samples], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
# This run should work, *but* see STOP ('1') as signals
|
||||
_, evaluated_signals = sess.run([features, signals])
|
||||
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(features)
|
||||
|
||||
|
||||
class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase):
|
||||
|
||||
def test_num_samples_divisible_by_batch_size(self):
|
||||
num_samples = 4
|
||||
batch_size = 2
|
||||
|
||||
params = {'batch_size': batch_size}
|
||||
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
dataset = input_fn(params)
|
||||
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
|
||||
add_padding=True)
|
||||
dataset_initializer = inputs.dataset_initializer()
|
||||
features, _ = inputs.features_and_labels()
|
||||
signals = inputs.signals()
|
||||
|
||||
# With padding, all shapes are static now.
|
||||
self.assertEqual(batch_size, features['a'].shape.as_list()[0])
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(dataset_initializer)
|
||||
|
||||
result, evaluated_signals = sess.run([features, signals])
|
||||
self.assertAllEqual(a[:batch_size], result['a'])
|
||||
self.assertAllEqual(b[:batch_size], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
self.assertAllEqual([0.] * batch_size,
|
||||
evaluated_signals['padding_mask'])
|
||||
|
||||
# This run should work as num_samples / batch_size = 2.
|
||||
result, evaluated_signals = sess.run([features, signals])
|
||||
self.assertAllEqual(a[batch_size:num_samples], result['a'])
|
||||
self.assertAllEqual(b[batch_size:num_samples], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
self.assertAllEqual([0.] * batch_size,
|
||||
evaluated_signals['padding_mask'])
|
||||
|
||||
# This run should work, *but* see STOP ('1') as signals
|
||||
_, evaluated_signals = sess.run([features, signals])
|
||||
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(features)
|
||||
|
||||
def test_num_samples_not_divisible_by_batch_size(self):
|
||||
num_samples = 5
|
||||
batch_size = 2
|
||||
|
||||
params = {'batch_size': batch_size}
|
||||
input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
dataset = input_fn(params)
|
||||
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
|
||||
add_padding=True)
|
||||
dataset_initializer = inputs.dataset_initializer()
|
||||
features, labels = inputs.features_and_labels()
|
||||
signals = inputs.signals()
|
||||
|
||||
# With padding, all shapes are static.
|
||||
self.assertEqual(batch_size, features['a'].shape.as_list()[0])
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(dataset_initializer)
|
||||
|
||||
evaluated_features, evaluated_labels, evaluated_signals = (
|
||||
sess.run([features, labels, signals]))
|
||||
self.assertAllEqual(a[:batch_size], evaluated_features['a'])
|
||||
self.assertAllEqual(b[:batch_size], evaluated_labels)
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
self.assertAllEqual([0.] * batch_size,
|
||||
evaluated_signals['padding_mask'])
|
||||
|
||||
# This run should work as num_samples / batch_size >= 2.
|
||||
evaluated_features, evaluated_labels, evaluated_signals = (
|
||||
sess.run([features, labels, signals]))
|
||||
self.assertAllEqual(a[batch_size:2*batch_size], evaluated_features['a'])
|
||||
self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels)
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
self.assertAllEqual([0.] * batch_size,
|
||||
evaluated_signals['padding_mask'])
|
||||
|
||||
# This is the final partial batch.
|
||||
evaluated_features, evaluated_labels, evaluated_signals = (
|
||||
sess.run([features, labels, signals]))
|
||||
real_batch_size = num_samples % batch_size
|
||||
|
||||
# Assert the real part.
|
||||
self.assertAllEqual(a[2*batch_size:num_samples],
|
||||
evaluated_features['a'][:real_batch_size])
|
||||
self.assertAllEqual(b[2*batch_size:num_samples],
|
||||
evaluated_labels[:real_batch_size])
|
||||
# Assert the padded part.
|
||||
self.assertAllEqual([0.0] * (batch_size - real_batch_size),
|
||||
evaluated_features['a'][real_batch_size:])
|
||||
self.assertAllEqual([[0.0]] * (batch_size - real_batch_size),
|
||||
evaluated_labels[real_batch_size:])
|
||||
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
padding = ([.0] * real_batch_size
|
||||
+ [1.] * (batch_size - real_batch_size))
|
||||
self.assertAllEqual(padding, evaluated_signals['padding_mask'])
|
||||
|
||||
# This run should work, *but* see STOP ('1') as signals
|
||||
_, evaluated_signals = sess.run([features, signals])
|
||||
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(features)
|
||||
|
||||
def test_slice(self):
|
||||
num_samples = 3
|
||||
batch_size = 2
|
||||
|
||||
params = {'batch_size': batch_size}
|
||||
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
dataset = input_fn(params)
|
||||
inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size,
|
||||
add_padding=True)
|
||||
dataset_initializer = inputs.dataset_initializer()
|
||||
features, _ = inputs.features_and_labels()
|
||||
signals = inputs.signals()
|
||||
|
||||
sliced_features = (
|
||||
tpu_estimator._PaddingSignals.slice_tensor_or_dict(
|
||||
features, signals))
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(dataset_initializer)
|
||||
|
||||
result, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertAllEqual(a[:batch_size], result['a'])
|
||||
self.assertAllEqual(b[:batch_size], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
# This is the final partial batch.
|
||||
result, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertEqual(1, len(result['a']))
|
||||
self.assertAllEqual(a[batch_size:num_samples], result['a'])
|
||||
self.assertAllEqual(b[batch_size:num_samples], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
# This run should work, *but* see STOP ('1') as signals
|
||||
_, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(sliced_features)
|
||||
|
||||
def test_slice_with_multi_invocations_per_step(self):
|
||||
num_samples = 3
|
||||
batch_size = 2
|
||||
|
||||
params = {'batch_size': batch_size}
|
||||
input_fn, (a, b) = make_input_fn(num_samples=num_samples)
|
||||
|
||||
with ops.Graph().as_default():
|
||||
dataset = input_fn(params)
|
||||
inputs = tpu_estimator._InputsWithStoppingSignals(
|
||||
dataset, batch_size, add_padding=True, num_invocations_per_step=2)
|
||||
dataset_initializer = inputs.dataset_initializer()
|
||||
features, _ = inputs.features_and_labels()
|
||||
signals = inputs.signals()
|
||||
|
||||
sliced_features = (
|
||||
tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))
|
||||
|
||||
with session.Session() as sess:
|
||||
sess.run(dataset_initializer)
|
||||
|
||||
result, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertAllEqual(a[:batch_size], result['a'])
|
||||
self.assertAllEqual(b[:batch_size], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
# This is the final partial batch.
|
||||
result, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertEqual(1, len(result['a']))
|
||||
self.assertAllEqual(a[batch_size:num_samples], result['a'])
|
||||
self.assertAllEqual(b[batch_size:num_samples], result['b'])
|
||||
self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])
|
||||
|
||||
# We should see 3 continuous batches with STOP ('1') as signals and all
|
||||
# of them have mask 1.
|
||||
_, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
|
||||
self.assertAllEqual([1.] * batch_size,
|
||||
evaluated_signals['padding_mask'])
|
||||
|
||||
_, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
|
||||
self.assertAllEqual([1.] * batch_size,
|
||||
evaluated_signals['padding_mask'])
|
||||
|
||||
_, evaluated_signals = sess.run([sliced_features, signals])
|
||||
self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])
|
||||
self.assertAllEqual([1.] * batch_size,
|
||||
evaluated_signals['padding_mask'])
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(sliced_features)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1,51 +1,23 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===================================================================
|
||||
|
||||
"""Utilities for the functionalities."""
|
||||
# ==============================================================================
|
||||
"""Stub file to maintain backwards compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
import six
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import training
|
||||
|
||||
def check_positive_integer(value, name):
|
||||
"""Checks whether `value` is a positive integer."""
|
||||
if not isinstance(value, six.integer_types):
|
||||
raise TypeError('{} must be int, got {}'.format(name, type(value)))
|
||||
|
||||
if value <= 0:
|
||||
raise ValueError('{} must be positive, got {}'.format(name, value))
|
||||
|
||||
|
||||
# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we
|
||||
# release a tensorflow_estimator with MultiHostDatasetInitializerHook in
|
||||
# python/estimator/util.py.
|
||||
class MultiHostDatasetInitializerHook(training.SessionRunHook):
|
||||
"""Creates a SessionRunHook that initializes all passed iterators."""
|
||||
|
||||
def __init__(self, dataset_initializers):
|
||||
self._initializers = dataset_initializers
|
||||
|
||||
def after_create_session(self, session, coord):
|
||||
del coord
|
||||
start = time.time()
|
||||
session.run(self._initializers)
|
||||
logging.info('Initialized dataset iterators in %d seconds',
|
||||
time.time() - start)
|
||||
# pylint: disable=wildcard-import,unused-import
|
||||
from tensorflow_estimator.python.estimator.tpu.util import *
|
||||
# pylint: enable=wildcard-import,unused-import
|
||||
|