Merge branch 'master' into patch-1
commit 348c4696d2

.gitignore (vendored) — 1 change
@@ -29,6 +29,7 @@ Podfile.lock
/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite
xcuserdata/**
/api_init_files_list.txt
/estimator_api_init_files_list.txt

# Android
.gradle

@@ -107,7 +107,7 @@ diff <my_cc_file> /tmp/my_cc_file.cc
#### Python coding style

Changes to TensorFlow Python code should conform to
-[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
+[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md)

Use `pylint` to check your Python changes. To install `pylint` and
retrieve TensorFlow's custom style definition:

@@ -15,9 +15,10 @@ If you open a GitHub issue, here is our policy:

### System information
- **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**:
- **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**:
- **Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device**:
- **TensorFlow installed from (source or binary)**:
- **TensorFlow version (use command below)**:
- **Python version**:
- **Bazel version (if compiling from source)**:
- **GCC/Compiler version (if compiling from source)**:
- **CUDA/cuDNN version**:
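The "use command below" bullet refers to a version-reporting snippet that usually accompanies this template; a minimal sketch of that check, assuming a TF 1.x install:

import tensorflow as tf

# Prints the build's git revision and release version, for pasting into the issue.
print(tf.GIT_VERSION, tf.VERSION)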

RELEASE.md — 22 changes
@@ -6,7 +6,7 @@
* Update `tf.keras` to the Keras 2.1.6 API.
* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082).
* Added support for core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees).
-* The [python interface](https://tensorflow-dot-devsite.googleplex.com/versions/r1.9/api_docs/python/tf/contrib/lite)
+* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite)
  for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md)
  has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again
  included in the standard `pip` installation.
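As a quick illustration of the new cuDNN-backed Keras layers, a minimal sketch (model shape and arguments are made up; these layers run only on CUDA-enabled builds):

import tensorflow as tf

# CuDNNLSTM/CuDNNGRU swap the generic recurrent kernels for cuDNN's
# fused implementation; on CPU-only builds use tf.keras.layers.LSTM instead.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=1000, output_dim=64),
    tf.keras.layers.CuDNNLSTM(128),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy')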

@@ -21,7 +21,7 @@
* The [distributions.Bijector](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/distributions/bijectors/Bijector)
  API supports broadcasting for Bijectors with new API changes.

-## Breaking Chances
+## Breaking Changes
* If you're opening empty variable scopes, replace `variable_scope('', ...)` by
  `variable_scope(tf.get_variable_scope(), ...)`.
* Headers used for building custom ops have been moved from site-packages/external into site-packages/tensorflow/include/external.
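The variable-scope replacement above, written out as a sketch (names are illustrative):

import tensorflow as tf

# Old pattern, no longer supported:
#   with tf.variable_scope('', reuse=tf.AUTO_REUSE):
#       v = tf.get_variable('v', shape=[1])
# Replacement: re-enter the current scope explicitly.
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    v = tf.get_variable('v', shape=[1])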

@@ -34,18 +34,22 @@
* Using `tf.layers` in a subclassed `tf.keras.Model` class. See
  [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details.
* `tf.data`:
-  * The `DatasetBase::DebugString()` method is now `const`.
-  * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets.
+  * `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators.
+  * `Dataset.list_files()` now produces deterministic results when `shuffle=False` or a `seed` is passed.
+  * `tf.contrib.data.sample_from_datasets()` and `tf.contrib.data.choose_from_datasets()` make it easier to sample or deterministically choose elements from multiple datasets.
+  * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings, and two infrequently used arguments were removed.
+  * (C++) `DatasetBase::DebugString()` is now `const`.
+  * (C++) `DatasetBase::MakeIterator()` has been renamed to `DatasetBase::MakeIteratorInternal()`.
+  * (C++) `IteratorBase::Initialize()` method was added to support raising errors during iterator construction.
* Eager Execution:
  * Added the ability to pause recording operations for gradient computation via `tf.GradientTape.stop_recording`.
  * Updated documentation, introductory notebooks.
* `tf.keras`:
  * Move Keras code out of _impl folder and remove API files.
  * `tf.keras.Model.save_weights` now saves in TensorFlow format by default.
  * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods.
* Accelerated Linear Algebra (XLA):
-* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB).
+* TensorFlow Debugger (tfdbg) CLI: fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB).
* `tf.contrib`:
-  * Add `tf.contrib.data.choose_from_datasets()`.
-  * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`.
  * `tf.contrib.framework.zero_initializer` supports ResourceVariable.
  * Adding "constrained_optimization" to tensorflow/contrib.
* Other:
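A small sketch of the `tf.data` additions listed above (the generator, datasets, and weights are made up for illustration):

import tensorflow as tf

# `args` parameterizes the generator, enabling nested/parameterized generators.
def gen(n):
    for i in range(n):
        yield i

ds = tf.data.Dataset.from_generator(gen, tf.int64, args=(5,))

# Randomly interleave elements from several datasets with given weights.
a = tf.data.Dataset.from_tensors(1).repeat()
b = tf.data.Dataset.from_tensors(2).repeat()
mixed = tf.contrib.data.sample_from_datasets([a, b], weights=[0.7, 0.3])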

@@ -55,7 +59,6 @@
* More consistent GcsFileSystem behavior for certain reads past EOF.
+* Update benchmark for tf.scan to match ranges across eager and graph modes.
* Fixed bug in `tf.reduce_prod` gradient for complex dtypes.
-* Add optional `args` argument to `Dataset.from_generator()`.
* Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)").
-* Benchmark for tf.scan in graph and eager modes.
* Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
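The hparams note above, as a runnable sketch (assuming `tf.contrib.training.HParams`):

import tensorflow as tf

hparams = tf.contrib.training.HParams()
# A name with an embedded '.' must be registered explicitly first.
hparams.add_hparam(name='a.b', value=0.0)
hparams.parse('a.b=1.0')
# The value is only reachable indirectly, via getattr/setattr.
print(getattr(hparams, 'a.b'))  # -> 1.0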

@@ -65,7 +68,6 @@
* LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`.
* Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary.
* Added `tf.train.Checkpoint` for reading/writing object-based checkpoints.
-* `Dataset.list_files()` now produces deterministic results when `shuffle=False` or a `seed` is passed.
* Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product.
* Allow LinearOperator to broadcast.
* SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other.
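A minimal sketch of the object-based checkpointing added by `tf.train.Checkpoint` (paths and attribute names are illustrative; shown with eager execution for brevity):

import tensorflow as tf

tf.enable_eager_execution()

# Objects are tracked by attribute name, not by graph variable name.
var = tf.Variable(1.0)
ckpt = tf.train.Checkpoint(my_var=var)
save_path = ckpt.save('/tmp/tf_ckpt/demo')  # writes an object-graph checkpoint
ckpt.restore(save_path)                     # restores by object relationship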

@@ -18,7 +18,7 @@ closure_repositories()
# files, in case the parsing of those build files depends on the bazel
# version we require here.
load("//tensorflow:version_check.bzl", "check_bazel_version_at_least")
-check_bazel_version_at_least("0.10.0")
+check_bazel_version_at_least("0.15.0")

load("//tensorflow:workspace.bzl", "tf_workspace")

configure.py — 120 changes
@@ -35,8 +35,8 @@ except ImportError:

_DEFAULT_CUDA_VERSION = '9.0'
_DEFAULT_CUDNN_VERSION = '7'
-_DEFAULT_NCCL_VERSION = '1.3'
-_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
+_DEFAULT_NCCL_VERSION = '2.2'
+_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '

@@ -680,7 +680,7 @@ def create_android_sdk_rule(environ_cp):
  if is_windows() or is_cygwin():
    default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA'])
  elif is_macos():
-    default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
+    default_sdk_path = '%s/library/Android/Sdk' % environ_cp['HOME']
  else:
    default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME']

@@ -835,6 +835,8 @@ def set_tf_cuda_version(environ_cp):
        '[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
    cuda_toolkit_path = get_from_env_or_user_or_default(
        environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
+    if is_windows() or is_cygwin():
+      cuda_toolkit_path = cygpath(cuda_toolkit_path)

    if is_windows():
      cuda_rt_lib_path = 'lib/x64/cudart.lib'

@@ -880,7 +882,7 @@ def set_tf_cudnn_version(environ_cp):
  default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
  ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
                    'installed. Refer to README.md for more details. [Default'
-                   ' is %s]:') % (tf_cudnn_version, default_cudnn_path)
+                   ' is %s]: ') % (tf_cudnn_version, default_cudnn_path)
  cudnn_install_path = get_from_env_or_user_or_default(
      environ_cp, 'CUDNN_INSTALL_PATH', ask_cudnn_path, default_cudnn_path)

@@ -1095,8 +1097,10 @@ def set_tf_nccl_install_path(environ_cp):
    raise ValueError('Currently NCCL is only supported on Linux platforms.')

  ask_nccl_version = (
-      'Please specify the NCCL version you want to use. '
-      '[Leave empty to default to NCCL %s]: ') % _DEFAULT_NCCL_VERSION
+      'Please specify the NCCL version you want to use. If NCCL %s is not '
+      'installed, then you can use version 1.3 that can be fetched '
+      'automatically but it may have worse performance with multiple GPUs. '
+      '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION)

  for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
    tf_nccl_version = get_from_env_or_user_or_default(

@@ -1197,7 +1201,7 @@ def set_tf_cuda_compute_capabilities(environ_cp):
        'https://developer.nvidia.com/cuda-gpus.\nPlease'
        ' note that each additional compute '
        'capability significantly increases your '
-        'build time and binary size. [Default is: %s]' %
+        'build time and binary size. [Default is: %s]: ' %
        default_cuda_compute_capabilities)
    tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
        environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',

@@ -1232,28 +1236,13 @@

def set_other_cuda_vars(environ_cp):
  """Set other CUDA related variables."""
-  if is_windows():
-    # The following three variables are needed for MSVC toolchain configuration
-    # in Bazel
-    environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH')
-    environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get(
-        'TF_CUDA_COMPUTE_CAPABILITIES')
-    environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1
-    write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH'))
-    write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE',
-                                environ_cp.get('CUDA_COMPUTE_CAPABILITIE'))
-    write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION',
-                                environ_cp.get('NO_WHOLE_ARCHIVE_OPTION'))
-    write_to_bazelrc('build --config=win-cuda')
-    write_to_bazelrc('test --config=win-cuda')
-  else:
-    # If CUDA is enabled, always use GPU during build and test.
-    if environ_cp.get('TF_CUDA_CLANG') == '1':
-      write_to_bazelrc('build --config=cuda_clang')
-      write_to_bazelrc('test --config=cuda_clang')
-    else:
-      write_to_bazelrc('build --config=cuda')
-      write_to_bazelrc('test --config=cuda')
+  # If CUDA is enabled, always use GPU during build and test.
+  if environ_cp.get('TF_CUDA_CLANG') == '1':
+    write_to_bazelrc('build --config=cuda_clang')
+    write_to_bazelrc('test --config=cuda_clang')
+  else:
+    write_to_bazelrc('build --config=cuda')
+    write_to_bazelrc('test --config=cuda')


def set_host_cxx_compiler(environ_cp):

@@ -1413,14 +1402,36 @@ def set_build_strip_flag():
  write_to_bazelrc('build --strip=always')


-def set_windows_build_flags():
-  if is_windows():
-    # The non-monolithic build is not supported yet
-    write_to_bazelrc('build --config monolithic')
-    # Suppress warning messages
-    write_to_bazelrc('build --copt=-w --host_copt=-w')
-    # Output more verbose information when something goes wrong
-    write_to_bazelrc('build --verbose_failures')
+def set_windows_build_flags(environ_cp):
+  """Set Windows specific build options."""
+  # The non-monolithic build is not supported yet
+  write_to_bazelrc('build --config monolithic')
+  # Suppress warning messages
+  write_to_bazelrc('build --copt=-w --host_copt=-w')
+  # Output more verbose information when something goes wrong
+  write_to_bazelrc('build --verbose_failures')
+  # The host and target platforms are the same in a Windows build, so we don't
+  # have to distinguish them. This avoids building the same targets twice.
+  write_to_bazelrc('build --distinct_host_configuration=false')
+  # Enable short object file path to avoid long path issue on Windows.
+  # TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0
+  # Short object file path will be enabled by default.
+  write_to_bazelrc('build --experimental_shortened_obj_file_path=true')
+
+  if get_var(
+      environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
+      True,
+      ('Would you like to override eigen strong inline for some C++ '
+       'compilation to reduce the compilation time?'),
+      'Eigen strong inline overridden.',
+      'Not overriding eigen strong inline, '
+      'some compilations could take more than 20 mins.'):
+    # Due to a known MSVC compiler issue
+    # https://github.com/tensorflow/tensorflow/issues/10521
+    # Overriding eigen strong inline speeds up the compiling of
+    # conv_grad_ops_3d.cc and conv_ops_3d.cc by 20 minutes,
+    # but this also hurts the performance. Let users decide what they want.
+    write_to_bazelrc('build --define=override_eigen_strong_inline=true')


def config_info_line(name, help_text):

@@ -1440,14 +1451,14 @@ def main():
  # environment variables.
  environ_cp = dict(os.environ)

-  check_bazel_version('0.10.0')
+  check_bazel_version('0.15.0')

  reset_tf_configure_bazelrc(args.workspace)
  cleanup_makefile()
  setup_python(environ_cp)

  if is_windows():
-    environ_cp['TF_NEED_S3'] = '0'
+    environ_cp['TF_NEED_AWS'] = '0'
    environ_cp['TF_NEED_GCP'] = '0'
    environ_cp['TF_NEED_HDFS'] = '0'
    environ_cp['TF_NEED_JEMALLOC'] = '0'

@@ -1460,19 +1471,31 @@ def main():
    # TODO(ibiryukov): Investigate using clang as a cpu or cuda compiler on
    # Windows.
    environ_cp['TF_DOWNLOAD_CLANG'] = '0'
    environ_cp['TF_ENABLE_XLA'] = '0'
    environ_cp['TF_NEED_GDR'] = '0'
    environ_cp['TF_NEED_VERBS'] = '0'
    environ_cp['TF_NEED_MPI'] = '0'
    environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'

  if is_macos():
    environ_cp['TF_NEED_JEMALLOC'] = '0'
    environ_cp['TF_NEED_TENSORRT'] = '0'

  # The numpy package on ppc64le uses OpenBLAS which has multi-threading
  # issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
  # runtime to allow the Tensorflow testcases which compare numpy
  # results to Tensorflow results to succeed.
  if is_ppc64le():
    write_action_env_to_bazelrc("OMP_NUM_THREADS", 1)

  set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
                'with_jemalloc', True)
  set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
                'with_gcp_support', True, 'gcp')
  set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
                'with_hdfs_support', True, 'hdfs')
-  set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
-                'with_s3_support', True, 's3')
+  set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
+                'with_aws_support', True, 'aws')
  set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
                'with_kafka_support', True, 'kafka')
  set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',

@@ -1536,7 +1559,8 @@ def main():
  set_grpc_build_flags()
  set_cc_opt_flags(environ_cp)
  set_build_strip_flag()
-  set_windows_build_flags()
+  if is_windows():
+    set_windows_build_flags(environ_cp)

  if get_var(
      environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',

@@ -1548,11 +1572,15 @@ def main():
      create_android_ndk_rule(environ_cp)
      create_android_sdk_rule(environ_cp)

-  print('Preconfigured Bazel build configs. You can use any of the below by '
-        'adding "--config=<>" to your build command. See tools/bazel.rc for '
-        'more details.')
-  config_info_line('mkl', 'Build with MKL support.')
-  config_info_line('monolithic', 'Config for mostly static monolithic build.')
+  # On Windows, we don't have MKL support and the build is always monolithic.
+  # So no need to print the following message.
+  # TODO(pcloudy): remove the following if check when they make sense on Windows
+  if not is_windows():
+    print('Preconfigured Bazel build configs. You can use any of the below by '
+          'adding "--config=<>" to your build command. See tools/bazel.rc for '
+          'more details.')
+    config_info_line('mkl', 'Build with MKL support.')
+    config_info_line('monolithic', 'Config for mostly static monolithic build.')

if __name__ == '__main__':
  main()

@@ -20,10 +20,18 @@ load(
    "tf_additional_binary_deps",
)
load(
-    "//tensorflow/tools/api/generator:api_gen.bzl",
+    "//tensorflow/python/tools/api/generator:api_gen.bzl",
    "gen_api_init_files",  # @unused
)

+# Config setting used when building for products
+# which require restricted licenses to be avoided.
+config_setting(
+    name = "no_lgpl_deps",
+    values = {"define": "__TENSORFLOW_NO_LGPL_DEPS__=1"},
+    visibility = ["//visibility:public"],
+)
+
# Config setting for determining if we are building for Android.
config_setting(
    name = "android",

@@ -216,8 +224,8 @@ config_setting(
)

config_setting(
-    name = "with_s3_support",
-    define_values = {"with_s3_support": "true"},
+    name = "with_aws_support",
+    define_values = {"with_aws_support": "true"},
    visibility = ["//visibility:public"],
)

@@ -244,8 +252,8 @@ config_setting(
)

config_setting(
-    name = "with_s3_support_windows_override",
-    define_values = {"with_s3_support": "true"},
+    name = "with_aws_support_windows_override",
+    define_values = {"with_aws_support": "true"},
    values = {"cpu": "x64_windows"},
    visibility = ["//visibility:public"],
)

@@ -279,8 +287,8 @@ config_setting(
)

config_setting(
-    name = "with_s3_support_android_override",
-    define_values = {"with_s3_support": "true"},
+    name = "with_aws_support_android_override",
+    define_values = {"with_aws_support": "true"},
    values = {"crosstool_top": "//external:android/crosstool"},
    visibility = ["//visibility:public"],
)

@@ -300,8 +308,8 @@ config_setting(
)

config_setting(
-    name = "with_s3_support_ios_override",
-    define_values = {"with_s3_support": "true"},
+    name = "with_aws_support_ios_override",
+    define_values = {"with_aws_support": "true"},
    values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
    visibility = ["//visibility:public"],
)

@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"

@@ -327,6 +328,7 @@ TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
}

void TF_DeleteBuffer(TF_Buffer* buffer) {
+  if (buffer == nullptr) return;
  if (buffer->data_deallocator != nullptr) {
    (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
                                buffer->length);

@@ -356,6 +358,7 @@ void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {

void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = Status::OK();
+  if (s == nullptr) return;
  delete s->session;
  delete s;
}

@@ -906,6 +909,7 @@ TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }

void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
+  if (lib_handle == nullptr) return;
  tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
  delete lib_handle;
}

@@ -963,6 +967,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
+TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);

#undef TF_DEVICELIST_METHOD

@@ -1852,6 +1857,7 @@ TF_Graph::TF_Graph()
TF_Graph* TF_NewGraph() { return new TF_Graph; }

void TF_DeleteGraph(TF_Graph* g) {
+  if (g == nullptr) return;
  g->mu.lock();
  g->delete_requested = true;
  const bool del = g->sessions.empty();

@@ -2068,7 +2074,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
    TF_Graph* graph, const TF_Buffer* graph_def,
    const TF_ImportGraphDefOptions* options, TF_Status* status) {
  GraphDef def;
-  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) {
+  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+                                       graph_def->length)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return nullptr;
  }

@@ -2098,7 +2105,8 @@ void TF_GraphImportGraphDefWithReturnOutputs(
    return;
  }
  GraphDef def;
-  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) {
+  if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+                                       graph_def->length)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }

@@ -2525,6 +2533,7 @@ void TF_CloseSession(TF_Session* s, TF_Status* status) {

void TF_DeleteSession(TF_Session* s, TF_Status* status) {
  status->status = Status::OK();
+  if (s == nullptr) return;
  TF_Graph* const graph = s->graph;
  if (graph != nullptr) {
    graph->mu.lock();

@@ -2723,7 +2732,34 @@ TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
#endif  // __ANDROID__
}

+TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
+  tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
+  TF_Buffer* ret = TF_NewBuffer();
+  status->status = MessageToBuffer(kernel_list, ret);
+  if (!status->status.ok()) {
+    TF_DeleteBuffer(ret);
+    return nullptr;
+  }
+  return ret;
+}
+
+TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
+  tensorflow::KernelList kernel_list =
+      tensorflow::GetRegisteredKernelsForOp(name);
+  TF_Buffer* ret = TF_NewBuffer();
+  status->status = MessageToBuffer(kernel_list, ret);
+  if (!status->status.ok()) {
+    TF_DeleteBuffer(ret);
+    return nullptr;
+  }
+  return ret;
+}
}  // end extern "C"

@@ -44,6 +44,7 @@ limitations under the License.
// * size_t is used to represent byte sizes of objects that are
//   materialized in the address space of the calling process.
// * int is used as an index into arrays.
+// * Deletion functions are safe to call on nullptr.
//
// Questions left to address:
// * Might at some point need a way for callers to provide their own Env.

@@ -1521,6 +1522,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
    const TF_DeviceList* list, int index, TF_Status* status);

+// Retrieve the incarnation number of a given device.
+//
+// If index is out of bounds, an error code will be set in the status object,
+// and 0 will be returned.
+TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation(
+    const TF_DeviceList* list, int index, TF_Status* status);
+
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels

@@ -1603,6 +1611,18 @@ TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map,
                                                 size_t name_len,
                                                 TF_Status* status);

+// --------------------------------------------------------------------------
+// Kernel definition information.
+
+// Returns a serialized KernelList protocol buffer containing KernelDefs for all
+// registered kernels.
+TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status);
+
+// Returns a serialized KernelList protocol buffer containing KernelDefs for all
+// kernels registered for the operation named `name`.
+TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp(
+    const char* name, TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif

@@ -57,6 +57,33 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  }
}

+TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
+                           unsigned char gpu_memory_allow_growth) {
+  tensorflow::ConfigProto config;
+  auto* optimizer_options =
+      config.mutable_graph_options()->mutable_optimizer_options();
+  if (enable_xla_compilation) {
+    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
+
+    // These XLA flags are needed to trigger XLA properly from C (more generally
+    // non-Python) clients. If this API is called again with `enable` set to
+    // false, it is safe to keep these flag values as is.
+    tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
+        tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+    flags->tf_xla_cpu_global_jit = true;
+    flags->tf_xla_min_cluster_size = 1;
+  } else {
+    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
+  }
+
+  auto* gpu_options = config.mutable_gpu_options();
+  gpu_options->set_allow_growth(gpu_memory_allow_growth);
+
+  TF_Buffer* ret = TF_NewBuffer();
+  TF_CHECK_OK(MessageToBuffer(config, ret));
+  return ret;
+}
+
const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();

@@ -55,11 +55,21 @@ extern "C" {
// set XLA flag values to prepare for XLA compilation. Otherwise set
// global_jit_level to OFF.
//
-// This API is syntax sugar over TF_SetConfig(), and is used by clients that
-// cannot read/write the tensorflow.ConfigProto proto.
+// This and the next API are syntax sugar over TF_SetConfig(), and is used by
+// clients that cannot read/write the tensorflow.ConfigProto proto.
+// TODO: Migrate to TF_CreateConfig() below.
TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
                                                   unsigned char enable);

+// Create a serialized tensorflow.ConfigProto proto, where:
+//
+// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if
+//    `enable_xla_compilation` is non-zero, and OFF otherwise.
+// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
+TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig(
+    unsigned char enable_xla_compilation,
+    unsigned char gpu_memory_allow_growth);
+
// Returns the graph content in a human-readable format, with length set in
// `len`. The format is subject to change in the future.
// The returned string is heap-allocated, and caller should call free() on it.
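For comparison, the ConfigProto that TF_CreateConfig() serializes can be sketched from the Python side roughly like this (TF 1.x API; the session is illustrative):

import tensorflow as tf

# Roughly what TF_CreateConfig(1, 1) encodes, assembled in Python instead.
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
    tf.OptimizerOptions.ON_1)           # enable_xla_compilation != 0
config.gpu_options.allow_growth = True  # gpu_memory_allow_growth
sess = tf.Session(config=config)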

@@ -1516,7 +1516,8 @@ void DefineStatefulFunction(const char* name, TF_Function** func) {

  TF_Output inputs[] = {};
  TF_Output outputs[] = {{random, 0}};
-  *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1,
+  *func = TF_GraphToFunction(func_graph.get(), name,
+                             /*append_hash_to_fn_name=*/false, -1,
                             /*opers=*/nullptr, 0, inputs, 1, outputs,
                             /*output_names=*/nullptr,
                             /*opts=*/nullptr, "", s.get());

@@ -29,9 +29,11 @@ limitations under the License.
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb_text.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

@@ -1424,6 +1426,29 @@ TEST(CAPI, SavedModelNullArgsAreValid) {
  TF_DeleteStatus(s);
}

+TEST(CAPI, DeletingNullPointerIsSafe) {
+  TF_Status* status = TF_NewStatus();
+
+  TF_DeleteStatus(nullptr);
+  TF_DeleteBuffer(nullptr);
+  TF_DeleteTensor(nullptr);
+  TF_DeleteSessionOptions(nullptr);
+  TF_DeleteGraph(nullptr);
+  TF_DeleteImportGraphDefOptions(nullptr);
+  TF_DeleteImportGraphDefResults(nullptr);
+  TF_DeleteFunction(nullptr);
+  TF_DeleteSession(nullptr, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeletePRunHandle(nullptr);
+  TF_DeleteDeprecatedSession(nullptr, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteDeviceList(nullptr);
+  TF_DeleteLibraryHandle(nullptr);
+  TF_DeleteApiDefMap(nullptr);
+
+  TF_DeleteStatus(status);
+}
+
REGISTER_OP("TestOpWithNoGradient")
    .Input("x: T")
    .Output("y: T")

@@ -2312,6 +2337,57 @@ TEST(TestApiDef, TestCreateApiDefWithOverwrites) {
  TF_DeleteLibraryHandle(lib);
}

+class DummyKernel : public tensorflow::OpKernel {
+ public:
+  explicit DummyKernel(tensorflow::OpKernelConstruction* context)
+      : OpKernel(context) {}
+  void Compute(tensorflow::OpKernelContext* context) override {}
+};
+
+// Test we can query kernels
+REGISTER_OP("TestOpWithSingleKernel")
+    .Input("a: float")
+    .Input("b: float")
+    .Output("o: float");
+REGISTER_KERNEL_BUILDER(
+    Name("TestOpWithSingleKernel").Device(tensorflow::DEVICE_CPU), DummyKernel);
+
+TEST(TestKernel, TestGetAllRegisteredKernels) {
+  TF_Status* status = TF_NewStatus();
+  TF_Buffer* kernel_list_buf = TF_GetAllRegisteredKernels(status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  KernelList kernel_list;
+  kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length);
+  ASSERT_GT(kernel_list.kernel_size(), 0);
+  TF_DeleteBuffer(kernel_list_buf);
+  TF_DeleteStatus(status);
+}
+
+TEST(TestKernel, TestGetRegisteredKernelsForOp) {
+  TF_Status* status = TF_NewStatus();
+  TF_Buffer* kernel_list_buf =
+      TF_GetRegisteredKernelsForOp("TestOpWithSingleKernel", status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  KernelList kernel_list;
+  kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length);
+  ASSERT_EQ(kernel_list.kernel_size(), 1);
+  EXPECT_EQ(kernel_list.kernel(0).op(), "TestOpWithSingleKernel");
+  EXPECT_EQ(kernel_list.kernel(0).device_type(), "CPU");
+  TF_DeleteBuffer(kernel_list_buf);
+  TF_DeleteStatus(status);
+}
+
+TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) {
+  TF_Status* status = TF_NewStatus();
+  TF_Buffer* kernel_list_buf = TF_GetRegisteredKernelsForOp("Unknown", status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  KernelList kernel_list;
+  kernel_list.ParseFromArray(kernel_list_buf->data, kernel_list_buf->length);
+  ASSERT_EQ(kernel_list.kernel_size(), 0);
+  TF_DeleteBuffer(kernel_list_buf);
+  TF_DeleteStatus(status);
+}
+
#undef EXPECT_TF_META

}  // namespace

@@ -664,17 +664,17 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status) {
-  tensorflow::Device* d = nullptr;
-  tensorflow::Device* op_device = nullptr;
-  const tensorflow::Tensor* t = nullptr;
-  status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
-  if (!status->status.ok()) return nullptr;
-  if (d != nullptr) {
+  if (!h->handle->OnHostCPU()) {
    status->status = tensorflow::errors::FailedPrecondition(
        "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
        "a tensorflow::Tensor");
    return nullptr;
  }
+  tensorflow::Device* d = nullptr;
+  tensorflow::Device* op_device = nullptr;
+  const tensorflow::Tensor* t = nullptr;
+  status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
+  if (!status->status.ok()) return nullptr;
  return t;
}

@@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
      }
    } else {
      any_gradient_nonzero = true;
-      auto new_gradients = vspace.AggregateGradients(grad_it->second);
+      Gradient* new_gradients = nullptr;
+      if (grad_it->second.size() == 1) {
+        new_gradients = grad_it->second.at(0);
+      } else {
+        new_gradients = vspace.AggregateGradients(grad_it->second);
+      }
      if (sources_set.find(grad_it->first) == sources_set.end()) {
        gradients.erase(grad_it);
      } else {

@@ -155,7 +155,7 @@ void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
    tensorflow::shape_inference::ShapeHandle shape;
    status->status =
        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
-    if (status->status.ok()) return;
+    if (!status->status.ok()) return;
    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
  }
  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);

@@ -47,6 +47,72 @@ Status SoftmaxGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Softmax", SoftmaxGrad);

+bool IsZero(const Scope& scope, const Output& grad) {
+  string op_type_name = grad.op().node()->type_string();
+  if (op_type_name == "ZerosLike" || op_type_name == "Zeros") {
+    return true;
+  }
+  // The Operation we were provided is not named something obvious so
+  // we need to actually look at its contents.
+  // The original python code did this by calling a utility function called
+  // tensor_util.constant_value.
+  // There is no C++ equivalent to tensor_util.constant_value so we do nothing
+  // for the moment.
+  return false;
+}
+
+// Multiply after broadcasting vec to match dimensions of mat.
+// Args:
+//   vec: A 1-D tensor of dimension [D0]
+//   mat: A 2-D tensor of dimension [D0, D1]
+//
+// Returns:
+//   A tensor of dimension [D0, D1], the result of vec * mat.
+Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) {
+  auto reshaped = ExpandDims(scope, vec, -1);
+  return Multiply(scope, reshaped, mat);
+}
+
+Status SoftmaxCrossEntropyWithLogitsGrad(const Scope& scope,
+                                         const Operation& op,
+                                         const std::vector<Output>& grad_inputs,
+                                         std::vector<Output>* grad_outputs) {
+  // Softmax gradient with cross entropy logits function.
+  // We multiply the backprop for cost with the gradients - op.output[1].
+  // There is no gradient for labels.
+
+  // The outputs of the network are at input index 0.
+  auto logits = op.input(0);
+  // The "truth" labels are at index 1.
+  auto softmax_grad = op.output(1);
+
+  // The loss is the output at index 0, and backprop is the output at index 1.
+  auto grad_loss = grad_inputs[0];
+  auto grad_grad = grad_inputs[1];
+
+  auto grad = BroadcastMul(scope, grad_loss, softmax_grad);
+  if (!IsZero(scope, grad_grad)) {
+    std::vector<int> axis;
+    auto logits_softmax = Softmax(scope, logits);
+
+    auto grad_grad_expand = ExpandDims(scope, grad_grad, 1);
+    auto logits_softmax_expand = ExpandDims(scope, logits_softmax, 2);
+    auto matmul_result =
+        BatchMatMul(scope, grad_grad_expand, logits_softmax_expand);
+    axis.push_back(1);
+    auto squeeze_result = Squeeze(scope, matmul_result, Squeeze::Axis(axis));
+    auto subtraction_result = Subtract(scope, grad_grad, squeeze_result);
+    auto multiply_result = Multiply(scope, subtraction_result, logits_softmax);
+    grad = Add(scope, grad, multiply_result);
+  }
+  auto minus_log_softmax = Multiply(scope, LogSoftmax(scope, logits), -1.0f);
+  grad_outputs->push_back(grad);
+  grad_outputs->push_back(BroadcastMul(scope, grad_loss, minus_log_softmax));
+  return scope.status();
+}
+REGISTER_GRADIENT_OP("SoftmaxCrossEntropyWithLogits",
+                     SoftmaxCrossEntropyWithLogitsGrad);
+
Status LogSoftmaxGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
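The BroadcastMul helper added above is just an expand-then-multiply; the same shape manipulation in a NumPy sketch (values made up):

import numpy as np

vec = np.array([1.0, 2.0])  # shape [D0]
mat = np.ones((2, 3))       # shape [D0, D1]
# ExpandDims(vec, -1) -> [D0, 1]; broadcasting the product yields [D0, D1].
result = np.expand_dims(vec, -1) * mat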

@@ -195,9 +261,9 @@ Status MaxPool3DGradHelper(const Scope& scope, const Operation& op,
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  MaxPool3DGrad::Attrs grad_attrs;
-  auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0],
-                          ksize, strides, padding,
-                          grad_attrs.DataFormat(data_format));
+  auto dx =
+      MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], ksize,
+                    strides, padding, grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}

@@ -216,10 +282,9 @@ Status AvgPoolGradHelper(const Scope& scope, const Operation& op,
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  internal::AvgPoolGrad::Attrs grad_attrs;
-  auto dx =
-      internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
-                            ksize, strides, padding,
-                            grad_attrs.DataFormat(data_format));
+  auto dx = internal::AvgPoolGrad(scope, Shape(scope, op.input(0)),
+                                  grad_inputs[0], ksize, strides, padding,
+                                  grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}

@@ -238,9 +303,9 @@ Status AvgPool3DGradHelper(const Scope& scope, const Operation& op,
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format));
  AvgPool3DGrad::Attrs grad_attrs;
-  auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0],
-                          ksize, strides, padding,
-                          grad_attrs.DataFormat(data_format));
+  auto dx =
+      AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], ksize,
+                    strides, padding, grad_attrs.DataFormat(data_format));
  grad_outputs->push_back(dx);
  return scope.status();
}

@@ -25,6 +25,8 @@ limitations under the License.
namespace tensorflow {
namespace {

+using ops::AvgPool;
+using ops::AvgPool3D;
using ops::BiasAdd;
using ops::Conv2D;
using ops::Elu;

@@ -33,11 +35,9 @@ using ops::FractionalMaxPool;
using ops::L2Loss;
using ops::LogSoftmax;
using ops::LRN;
-using ops::AvgPool;
-using ops::AvgPool3D;
using ops::MaxPool;
-using ops::MaxPoolV2;
using ops::MaxPool3D;
+using ops::MaxPoolV2;
using ops::Placeholder;
using ops::Relu;
using ops::Relu6;

@@ -111,6 +111,20 @@ TEST_F(NNGradTest, SoftmaxGrad) {
  RunTest(x, shape, y, shape);
}

+TEST_F(NNGradTest, SoftmaxCrossEntropyWithLogitsGrad) {
+  TensorShape logits_shape({5, 3});
+  TensorShape loss_shape({5});
+
+  auto logits = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+  auto labels = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(logits_shape));
+  auto y =
+      tensorflow::ops::SoftmaxCrossEntropyWithLogits(scope_, logits, labels);
+  // Note the reversal of the backprop and loss orders. Issue #18734 has been
+  // opened for this.
+  RunTest({logits, labels}, {logits_shape, logits_shape}, {y.backprop, y.loss},
+          {logits_shape, loss_shape});
+}
+
TEST_F(NNGradTest, LogSoftmaxGrad) {
  TensorShape shape({5, 3});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));

@@ -253,7 +267,7 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) {
  RunTest(x, x_shape, y, y_shape);
}

-TEST_F(NNGradTest, LRN){
+TEST_F(NNGradTest, LRN) {
  TensorShape x_shape({1, 1, 2, 1});
  auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
  auto y = LRN(scope_, x);

@@ -33,6 +33,46 @@ cc_library(
    hdrs = ["tag_constants.h"],
)

+cc_library(
+    name = "reader",
+    srcs = ["reader.cc"],
+    hdrs = ["reader.h"],
+    deps = [
+        ":constants",
+    ] + if_not_mobile([
+        # TODO(b/111634734): :lib and :protos_all contain dependencies that
+        # cannot be built on mobile platforms. Instead, include the appropriate
+        # tf_lib depending on the build platform.
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+    ]) + if_mobile([
+        # Mobile-friendly SavedModel proto. See go/portable-proto for more info.
+        "//tensorflow/core:saved_model_portable_proto",
+    ]) + if_android([
+        "//tensorflow/core:android_tensorflow_lib",
+    ]) + if_ios([
+        "//tensorflow/core:ios_tensorflow_lib",
+    ]),
+)
+
+tf_cc_test(
+    name = "reader_test",
+    srcs = ["reader_test.cc"],
+    data = [
+        ":saved_model_half_plus_two",
+    ],
+    linkstatic = 1,
+    deps = [
+        ":constants",
+        ":reader",
+        ":tag_constants",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
cc_library(
    name = "loader",
    hdrs = ["loader.h"],

@@ -54,6 +94,7 @@ cc_library(
    hdrs = ["loader.h"],
    deps = [
        ":constants",
+        ":reader",
    ] + if_not_mobile([
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",

@@ -18,8 +18,10 @@ limitations under the License.
#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"

@@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New(
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";

-Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
-  const string saved_model_pb_path =
-      io::JoinPath(export_dir, kSavedModelFilenamePb);
-  if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
-    return ReadBinaryProto(Env::Default(), saved_model_pb_path,
-                           saved_model_proto);
-  }
-  const string saved_model_pbtxt_path =
-      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
-  if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
-    return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
-                         saved_model_proto);
-  }
-  return Status(error::Code::NOT_FOUND,
-                "Could not find SavedModel .pb or .pbtxt at supplied export "
-                "directory path: " +
-                    export_dir);
-}
-
-string GetTagsAsString(const std::unordered_set<string>& tags) {
-  string tags_as_string = "{ ";
-  for (const string& tag : tags) {
-    tags_as_string = strings::StrCat(tags_as_string, tag, " ");
-  }
-  tags_as_string = strings::StrCat(tags_as_string, "}");
-  return tags_as_string;
-}
-
-Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
-                              const std::unordered_set<string>& tags,
-                              MetaGraphDef* meta_graph_def_to_load) {
-  for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
-    // Get tags from the meta_graph_def.
-    std::unordered_set<string> graph_tags;
-    for (const string& tag : meta_graph_def.meta_info_def().tags()) {
-      graph_tags.insert(tag);
-    }
-    // Match with the set of tags provided.
-    if (graph_tags == tags) {
-      *meta_graph_def_to_load = meta_graph_def;
-      return Status::OK();
-    }
-  }
-  return Status(error::Code::NOT_FOUND,
-                "Could not find meta graph def matching supplied tags: " +
-                    GetTagsAsString(tags) +
-                    ". To inspect available tag-sets in the SavedModel, please "
-                    "use the SavedModel CLI: `saved_model_cli`");
-}
-
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
                                const SessionOptions& session_options,
                                std::unique_ptr<Session>* session) {

@@ -134,10 +86,11 @@ bool HasMainOp(const MetaGraphDef& meta_graph_def) {
Status RunMainOp(const RunOptions& run_options, const string& export_dir,
                 const MetaGraphDef& meta_graph_def,
                 const std::vector<AssetFileDef>& asset_file_defs,
-                 Session* session) {
-  LOG(INFO) << "Running MainOp on SavedModel bundle.";
+                 Session* session, const string& main_op_key) {
+  LOG(INFO) << "Running MainOp with key " << main_op_key
+            << " on SavedModel bundle.";
  const auto& collection_def_map = meta_graph_def.collection_def();
-  const auto main_op_it = collection_def_map.find(kSavedModelMainOpKey);
+  const auto main_op_it = collection_def_map.find(main_op_key);
  if (main_op_it != collection_def_map.end()) {
    if (main_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(

@@ -189,30 +142,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
                      nullptr /* outputs */, &run_metadata);
}

-Status RunLegacyInitOp(const RunOptions& run_options, const string& export_dir,
-                       const MetaGraphDef& meta_graph_def,
-                       const std::vector<AssetFileDef>& asset_file_defs,
-                       Session* session) {
-  LOG(INFO) << "Running LegacyInitOp on SavedModel bundle.";
-  const auto& collection_def_map = meta_graph_def.collection_def();
-  const auto init_op_it = collection_def_map.find(kSavedModelLegacyInitOpKey);
-  if (init_op_it != collection_def_map.end()) {
-    if (init_op_it->second.node_list().value_size() != 1) {
-      return errors::FailedPrecondition(strings::StrCat(
-          "Expected exactly one serving init op in : ", export_dir));
-    }
-    std::vector<std::pair<string, Tensor>> inputs;
-    AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
-    RunMetadata run_metadata;
-    const StringPiece legacy_init_op_name =
-        init_op_it->second.node_list().value(0);
-    return session->Run(run_options, inputs, {},
-                        {legacy_init_op_name.ToString()}, nullptr /* outputs */,
-                        &run_metadata);
-  }
-  return Status::OK();
-}
-
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  const auto& collection_def_map = meta_graph_def.collection_def();

@@ -235,18 +164,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const string& export_dir,
                              const std::unordered_set<string>& tags,
                              SavedModelBundle* const bundle) {
-  if (!MaybeSavedModelDirectory(export_dir)) {
-    return Status(error::Code::NOT_FOUND,
-                  "SavedModel not found in export directory: " + export_dir);
-  }
-  LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags)
-            << "; from: " << export_dir;
-
-  SavedModel saved_model_proto;
-  TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
-
-  TF_RETURN_IF_ERROR(
-      FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));
+  TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
+                                                    &bundle->meta_graph_def));

  TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
      bundle->meta_graph_def, session_options, &bundle->session));

@@ -262,11 +181,11 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
  if (HasMainOp(bundle->meta_graph_def)) {
    TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
                                 bundle->meta_graph_def, asset_file_defs,
-                                 bundle->session.get()));
+                                 bundle->session.get(), kSavedModelMainOpKey));
  } else {
-    TF_RETURN_IF_ERROR(RunLegacyInitOp(run_options, export_dir,
-                                       bundle->meta_graph_def, asset_file_defs,
-                                       bundle->session.get()));
+    TF_RETURN_IF_ERROR(RunMainOp(
+        run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
+        bundle->session.get(), kSavedModelLegacyInitOpKey));
  }
  return Status::OK();
}

@@ -288,8 +207,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
    return end_microseconds - start_microseconds;
  }();
  auto log_and_count = [&](const string& status_str) {
-    LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags)
-              << "; Status: " << status_str << ". Took "
+    LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
+              << " }; Status: " << status_str << ". Took "
              << load_latency_microsecs << " microseconds.";
    load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
  };

tensorflow/cc/saved_model/reader.cc — new file, 88 lines
@@ -0,0 +1,88 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/reader.h"

#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"

namespace tensorflow {
namespace {

Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
  LOG(INFO) << "Reading SavedModel from: " << export_dir;

  const string saved_model_pb_path =
      io::JoinPath(export_dir, kSavedModelFilenamePb);
  if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
    return ReadBinaryProto(Env::Default(), saved_model_pb_path,
                           saved_model_proto);
  }
  const string saved_model_pbtxt_path =
      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
  if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
    return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
                         saved_model_proto);
  }
  return Status(error::Code::NOT_FOUND,
                "Could not find SavedModel .pb or .pbtxt at supplied export "
                "directory path: " +
                    export_dir);
}

Status FindMetaGraphDef(const SavedModel& saved_model_proto,
                        const std::unordered_set<string>& tags,
                        MetaGraphDef* meta_graph_def) {
  LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
            << " }";
  for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
    // Get tags from the graph_def.
    std::unordered_set<string> graph_tags;
    for (const string& tag : graph_def.meta_info_def().tags()) {
      graph_tags.insert(tag);
    }
    // Match with the set of tags provided.
    if (graph_tags == tags) {
      *meta_graph_def = graph_def;
      return Status::OK();
    }
  }
  return Status(
      error::Code::NOT_FOUND,
      strings::StrCat(
          "Could not find meta graph def matching supplied tags: { ",
          str_util::Join(tags, " "),
          " }. To inspect available tag-sets in the SavedModel, please "
          "use the SavedModel CLI: `saved_model_cli`"));
}

}  // namespace

Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
                                      const std::unordered_set<string>& tags,
                                      MetaGraphDef* const meta_graph_def) {
  SavedModel saved_model_proto;
  TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
  TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
  return Status::OK();
}

}  // namespace tensorflow
39  tensorflow/cc/saved_model/reader.h  Normal file
@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

/// Functions to read the SavedModel proto, or parts of it.

#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_
#define TENSORFLOW_CC_SAVED_MODEL_READER_H_

#include <string>
#include <unordered_set>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {

// Reads the SavedModel proto from saved_model.pb(txt) in the given directory,
// finds the MetaGraphDef that matches the given set of tags and writes it to
// the `meta_graph_def` parameter. Returns a failure status when the SavedModel
// file does not exist or no MetaGraphDef matches the tags.
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
                                      const std::unordered_set<string>& tags,
                                      MetaGraphDef* const meta_graph_def);

}  // namespace tensorflow

#endif  // TENSORFLOW_CC_SAVED_MODEL_READER_H_
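[Editor's note] As a quick orientation for the new API, here is a minimal, hypothetical caller; the function name and export path are placeholders, not part of the commit. Available tag-sets can also be listed with the `saved_model_cli` tool that the error message in reader.cc points to.

#include <unordered_set>

#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical caller: reads only the MetaGraphDef, without restoring
// variables the way LoadSavedModel does.
void SketchReadMetaGraph() {
  tensorflow::MetaGraphDef meta_graph_def;
  const tensorflow::Status status = tensorflow::ReadMetaGraphDefFromSavedModel(
      "/tmp/exported_model/1",  // Placeholder export directory.
      {tensorflow::kSavedModelTagServe}, &meta_graph_def);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to read MetaGraphDef: " << status.error_message();
    return;
  }
  LOG(INFO) << "Read graph with " << meta_graph_def.graph_def().node_size()
            << " nodes.";
}

This is the lightweight path: as reader.cc shows, the proto is loaded from saved_model.pb if present, otherwise from saved_model.pbtxt, and only the MetaGraphDef matching the tag set is returned.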
108  tensorflow/cc/saved_model/reader_test.cc  Normal file
@ -0,0 +1,108 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/reader.h"

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

constexpr char kTestDataPbTxt[] =
    "cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
    "cc/saved_model/testdata/half_plus_two/00000123";

class ReaderTest : public ::testing::Test {
 protected:
  ReaderTest() {}

  void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) {
    const auto& tags = meta_graph_def.meta_info_def().tags();
    EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) !=
                tags.end());
    EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), "");
    EXPECT_EQ(
        meta_graph_def.signature_def().at("serving_default").method_name(),
        "tensorflow/serving/predict");
  }
};

TEST_F(ReaderTest, TagMatch) {
  MetaGraphDef meta_graph_def;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                              &meta_graph_def));
  CheckMetaGraphDef(meta_graph_def);
}

TEST_F(ReaderTest, NoTagMatch) {
  MetaGraphDef meta_graph_def;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
                                             &meta_graph_def);
  EXPECT_FALSE(st.ok());
  EXPECT_TRUE(str_util::StrContains(
      st.error_message(),
      "Could not find meta graph def matching supplied tags: { missing-tag }"))
      << st.error_message();
}

TEST_F(ReaderTest, NoTagMatchMultiple) {
  MetaGraphDef meta_graph_def;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
  Status st = ReadMetaGraphDefFromSavedModel(
      export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
  EXPECT_FALSE(st.ok());
  EXPECT_TRUE(str_util::StrContains(
      st.error_message(),
      "Could not find meta graph def matching supplied tags: "))
      << st.error_message();
}

TEST_F(ReaderTest, PbtxtFormat) {
  MetaGraphDef meta_graph_def;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
  TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                              &meta_graph_def));
  CheckMetaGraphDef(meta_graph_def);
}

TEST_F(ReaderTest, InvalidExportPath) {
  MetaGraphDef meta_graph_def;

  const string export_dir =
      io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
  Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
                                             &meta_graph_def);
  EXPECT_FALSE(st.ok());
}

}  // namespace
}  // namespace tensorflow
@ -68,6 +68,7 @@ cc_library(
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
        "//tensorflow/core:core_cpu_internal",
@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
@ -176,9 +176,11 @@ cc_library(
        "//tensorflow/core/kernels:cast_op",
        "//tensorflow/core/kernels:constant_op",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:fifo_queue",
        "//tensorflow/core/kernels:identity_n_op",
        "//tensorflow/core/kernels:identity_op",
        "//tensorflow/core/kernels:no_op",
        "//tensorflow/core/kernels:queue_op",
        "//tensorflow/core/kernels:resource_variable_ops",
        "//tensorflow/core/kernels:sendrecv_ops",
        "//tensorflow/core/kernels:shape_ops",
@ -302,11 +304,13 @@ cc_library(
    name = "compilation_passes",
    srcs = [
        "build_xla_launch_ops_pass.cc",
        "deadness_analysis.cc",
        "encapsulate_subgraphs_pass.cc",
        "mark_for_compilation_pass.cc",
    ],
    hdrs = [
        "build_xla_launch_ops_pass.h",
        "deadness_analysis.h",
        "encapsulate_subgraphs_pass.h",
        "mark_for_compilation_pass.h",
    ],
@ -323,6 +327,7 @@ cc_library(
        "//tensorflow/compiler/tf2xla:dump_graph",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
@ -375,6 +380,7 @@ tf_cc_test(
    name = "compilation_passes_test",
    size = "small",
    srcs = [
        "deadness_analysis_test.cc",
        "encapsulate_subgraphs_pass_test.cc",
        "mark_for_compilation_pass_test.cc",
    ],
@ -385,6 +391,7 @@ tf_cc_test(
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/cc:sendrecv_ops",
        "//tensorflow/compiler/jit/kernels:xla_launch_op",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
@ -456,6 +463,7 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":common",
        ":compilation_passes",
        ":union_find",
        ":xla_cluster_util",
        "//tensorflow/compiler/jit/graphcycles",
566  tensorflow/compiler/jit/deadness_analysis.cc  Normal file
@ -0,0 +1,566 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"

// ALGORITHM OVERVIEW
//
// We map every output produced by each node in the TensorFlow graph (including
// control dependence) into an instance of the Predicate class.  Instances of
// Predicate denote logical formulas, and mapping a node `n` to a predicate
// `pred` implies that `n` is executed whenever `pred` is true.  Then we can
// deduce mismatching liveness in the inputs to a node by comparing the
// predicates those inputs are mapped to.
//
// Loops are handled pessimistically -- we map Merge nodes with backedges to
// uninterpreted symbols (the same kind we use to represent Switch and _Recv).
// Predicate equality has to hold over all possible assignments to these
// uninterpreted symbols.

namespace tensorflow {

namespace {

// Represents a logical predicate, used as described in the algorithm overview
// above.
class Predicate {
 public:
  enum class Kind { kAnd, kOr, kNot, kSymbol };

  virtual string ToString() const = 0;
  int64 hash() const { return hash_; }

  virtual Kind kind() const = 0;
  virtual ~Predicate() {}

 protected:
  explicit Predicate(int64 hash) : hash_(hash) {}

 private:
  const int64 hash_;

  TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
};

int64 HashPredicateSequence(Predicate::Kind kind,
                            gtl::ArraySlice<Predicate*> preds) {
  int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
  for (Predicate* pred : preds) {
    hash = Hash64Combine(hash, pred->hash());
  }
  return hash;
}

// Represents a logical conjunction of a set of predicates.
class AndPredicate : public Predicate {
 public:
  explicit AndPredicate(std::vector<Predicate*> operands)
      : Predicate(HashPredicateSequence(Kind::kAnd, operands)),
        operands_(std::move(operands)) {}

  string ToString() const override {
    if (operands().empty()) {
      return "#true";
    }

    std::vector<string> operands_str;
    std::transform(operands().begin(), operands().end(),
                   std::back_inserter(operands_str),
                   [](Predicate* pred) { return pred->ToString(); });

    return strings::StrCat("(", str_util::Join(operands_str, " & "), ")");
  }

  Kind kind() const override { return Kind::kAnd; }

  const gtl::ArraySlice<Predicate*> operands() const { return operands_; }

 private:
  std::vector<Predicate*> operands_;
};

// Represents a logical disjunction of a set of predicates.
class OrPredicate : public Predicate {
 public:
  explicit OrPredicate(std::vector<Predicate*> operands)
      : Predicate(HashPredicateSequence(Kind::kOr, operands)),
        operands_(std::move(operands)) {}

  string ToString() const override {
    if (operands().empty()) {
      return "#false";
    }

    std::vector<string> operands_str;
    std::transform(operands().begin(), operands().end(),
                   std::back_inserter(operands_str),
                   [](Predicate* pred) { return pred->ToString(); });

    return strings::StrCat("(", str_util::Join(operands_str, " | "), ")");
  }

  Kind kind() const override { return Kind::kOr; }
  const gtl::ArraySlice<Predicate*> operands() const { return operands_; }

 private:
  std::vector<Predicate*> operands_;
};

// Represents the logical negation of a predicate.
class NotPredicate : public Predicate {
 public:
  explicit NotPredicate(Predicate* operand)
      : Predicate(HashPredicateSequence(Kind::kNot, {operand})),
        operand_(operand) {}

  string ToString() const override {
    return strings::StrCat("~", operand()->ToString());
  }

  Kind kind() const override { return Kind::kNot; }
  Predicate* operand() const { return operand_; }

 private:
  Predicate* operand_;
};

// Represents an uninterpreted symbol in a logical predicate.
//
// Two predicates are equivalent iff they are equivalent for all assignments to
// the symbols contained in them.
class SymbolPredicate : public Predicate {
 public:
  explicit SymbolPredicate(TensorId tensor_id, bool must_be_true)
      : Predicate(Hash(tensor_id, must_be_true)),
        tensor_id_(std::move(tensor_id)),
        must_be_true_(must_be_true) {}

  string ToString() const override { return tensor_id_.ToString(); }
  Kind kind() const override { return Kind::kSymbol; }

  // If `must_be_true()` is true this SymbolPredicate represents the
  // proposition "tensor_id() is live and evaluates to true".
  //
  // If `must_be_true()` is false then this SymbolPredicate represents the
  // proposition "tensor_id() is live (and may evaluate to any value)".
  TensorId tensor_id() const { return tensor_id_; }
  bool must_be_true() const { return must_be_true_; }

 private:
  TensorId tensor_id_;
  bool must_be_true_;

  static int64 Hash(const TensorId tensor_id, bool must_be_true) {
    return Hash64Combine(
        ::tensorflow::hash<bool>()(must_be_true),
        Hash64Combine(::tensorflow::hash<Predicate::Kind>()(Kind::kSymbol),
                      TensorId::Hasher{}(tensor_id)));
  }
};

// Creates and owns Predicate instances.  Simplifies predicates as it creates
// them.
class PredicateFactory {
 public:
  Predicate* MakeAndPredicate(gtl::ArraySlice<Predicate*> operands) {
    return MakeAndOrImpl(operands, /*is_and=*/true);
  }

  Predicate* MakeOrPredicate(gtl::ArraySlice<Predicate*> operands) {
    return MakeAndOrImpl(operands, /*is_and=*/false);
  }

  Predicate* MakeNotPredicate(Predicate* pred) {
    SignatureForNot signature = pred;
    auto it = interned_not_instances_.find(signature);
    if (it == interned_not_instances_.end()) {
      std::unique_ptr<Predicate> new_pred = Make<NotPredicate>(pred);
      Predicate* new_pred_ptr = new_pred.get();
      interned_not_instances_.emplace(signature, std::move(new_pred));
      return new_pred_ptr;
    } else {
      return it->second.get();
    }
  }

  Predicate* MakeSymbolPredicate(TensorId tensor_id, bool must_be_true) {
    SignatureForSymbol signature = {tensor_id, must_be_true};
    auto it = interned_symbol_instances_.find(signature);
    if (it == interned_symbol_instances_.end()) {
      std::unique_ptr<Predicate> new_pred =
          Make<SymbolPredicate>(tensor_id, must_be_true);
      Predicate* new_pred_ptr = new_pred.get();
      interned_symbol_instances_.emplace(std::move(signature),
                                         std::move(new_pred));
      return new_pred_ptr;
    } else {
      return it->second.get();
    }
  }

  Predicate* MakeTrue() { return MakeAndPredicate({}); }
  Predicate* MakeFalse() { return MakeOrPredicate({}); }

 private:
  template <typename PredicateT, typename... Args>
  std::unique_ptr<Predicate> Make(Args&&... args) {
    return std::unique_ptr<PredicateT>(
        new PredicateT(std::forward<Args>(args)...));
  }

  Predicate* MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands, bool is_and);

  // Predicate instances are interned, meaning that there is only a single
  // instance of a Predicate object with a given content.  This makes checking
  // for structural equality super-cheap -- we can just compare pointers.
  //
  // We intern predicates by maintaining a map from the content of a Predicate
  // to the only instance of said predicate we allow to exist in the
  // interned_and_or_instances_, interned_not_instances_ and
  // interned_symbol_instances_ fields.  These maps also double up as storage
  // for the owning pointers to predicate instances.

  using SignatureForAndOr =
      std::pair<Predicate::Kind, gtl::ArraySlice<Predicate*>>;
  using SignatureForNot = Predicate*;
  using SignatureForSymbol = std::pair<SafeTensorId, bool>;

  struct HashSignatureForAndOr {
    size_t operator()(const SignatureForAndOr& signature) const {
      size_t hash = ::tensorflow::hash<Predicate::Kind>()(signature.first);
      for (Predicate* p : signature.second) {
        hash = Hash64Combine(hash, ::tensorflow::hash<Predicate*>()(p));
      }
      return hash;
    }
  };

  struct HashSignatureForSymbol {
    size_t operator()(const SignatureForSymbol& signature) const {
      return Hash64Combine(SafeTensorId::Hasher()(signature.first),
                           ::tensorflow::hash<bool>()(signature.second));
    }
  };

  gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>,
               HashSignatureForAndOr>
      interned_and_or_instances_;
  gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
      interned_not_instances_;
  gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
               HashSignatureForSymbol>
      interned_symbol_instances_;
};

// Common code to create AndPredicate or OrPredicate instances.
Predicate* PredicateFactory::MakeAndOrImpl(gtl::ArraySlice<Predicate*> operands,
                                           bool is_and) {
  Predicate::Kind pred_kind =
      is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
  gtl::FlatSet<Predicate*> simplified_ops_set;
  std::vector<Predicate*> simplified_ops;
  for (Predicate* op : operands) {
    // Simplify A&A => A and A|A => A.
    if (!simplified_ops_set.insert(op).second) {
      continue;
    }

    if (op->kind() == pred_kind) {
      // "Inline" the operands of an inner And/Or into the parent And/Or.
      gtl::ArraySlice<Predicate*> operands =
          is_and ? dynamic_cast<AndPredicate*>(op)->operands()
                 : dynamic_cast<OrPredicate*>(op)->operands();
      for (Predicate* subop : operands) {
        if (simplified_ops_set.insert(subop).second) {
          simplified_ops.push_back(subop);
        }
      }
    } else {
      simplified_ops.push_back(op);
    }
  }

  if (simplified_ops.size() == 1) {
    return simplified_ops[0];
  }

  // Simplify "A&~A=>False" and "A|~A=>True".
  gtl::FlatSet<Predicate*> negated_ops;
  for (Predicate* op : simplified_ops) {
    if (op->kind() == Predicate::Kind::kNot) {
      negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
    }
  }

  for (Predicate* op : simplified_ops) {
    if (negated_ops.count(op)) {
      return is_and ? MakeFalse() : MakeTrue();
    }
  }

  std::stable_sort(
      simplified_ops.begin(), simplified_ops.end(),
      [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });

  auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
  if (it == interned_and_or_instances_.end()) {
    simplified_ops.shrink_to_fit();
    // NB!  Because we'll use a non-owning reference to simplified_ops in the
    // key for interned_and_or_instances_ we need to be careful to std::move()
    // it all the way through.
    gtl::ArraySlice<Predicate*> operands_slice = simplified_ops;
    std::unique_ptr<Predicate> new_pred =
        is_and ? Make<AndPredicate>(std::move(simplified_ops))
               : Make<OrPredicate>(std::move(simplified_ops));

    Predicate* new_pred_ptr = new_pred.get();
    CHECK(interned_and_or_instances_
              .emplace(SignatureForAndOr(pred_kind, operands_slice),
                       std::move(new_pred))
              .second);
    return new_pred_ptr;
  } else {
    return it->second.get();
  }
}

class DeadnessAnalysisImpl : public DeadnessAnalysis {
 public:
  explicit DeadnessAnalysisImpl(const Graph* graph)
      : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}

  Status Populate();
  bool HasInputsWithMismatchingDeadness(const Node& node) override;
  void Print() const override;

 private:
  enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };

  std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
  void SetPred(Node* n, int output_idx, Predicate* pred) {
    CHECK(
        predicate_map_.insert({TensorId(n->name(), output_idx), pred}).second);
  }
  void SetPred(Node* n, gtl::ArraySlice<int> output_idxs, Predicate* pred) {
    for (int output_idx : output_idxs) {
      SetPred(n, output_idx, pred);
    }
  }

  Status HandleSwitch(Node* n);
  Status HandleMerge(Node* n);
  Status HandleRecv(Node* n);
  Status HandleGeneric(Node* n);

  const Graph& graph_;
  gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
  PredicateFactory predicate_factory_;
  bool vlog_;
};

TensorId InputEdgeToTensorId(const Edge* e) {
  return TensorId(e->src()->name(), e->src_output());
}

std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
    Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) {
  std::vector<Predicate*> incoming_preds;
  for (const Edge* in_edge : n->in_edges()) {
    bool should_process =
        edge_kind == EdgeKind::kDataAndControl ||
        (in_edge->IsControlEdge() && edge_kind == EdgeKind::kControlOnly) ||
        (!in_edge->IsControlEdge() && edge_kind == EdgeKind::kDataOnly);

    if (should_process) {
      auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
      CHECK(it != predicate_map_.end());
      incoming_preds.push_back(it->second);
    }
  }
  return incoming_preds;
}

Status DeadnessAnalysisImpl::HandleSwitch(Node* n) {
  std::vector<Predicate*> input_preds =
      GetIncomingPreds(n, EdgeKind::kDataAndControl);
  const Edge* pred_edge;
  TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
  Predicate* true_switch = predicate_factory_.MakeSymbolPredicate(
      TensorId(pred_edge->src()->name(), pred_edge->src_output()),
      /*must_be_true=*/true);
  Predicate* false_switch = predicate_factory_.MakeNotPredicate(true_switch);

  // Output 0 is alive iff all inputs are alive and the condition is false.
  input_preds.push_back(false_switch);
  SetPred(n, 0, predicate_factory_.MakeAndPredicate(input_preds));
  input_preds.pop_back();

  // Output 1 is alive iff all inputs are alive and the condition is true.
  input_preds.push_back(true_switch);
  SetPred(n, 1, predicate_factory_.MakeAndPredicate(input_preds));
  input_preds.pop_back();

  // Control is alive iff all inputs are alive.
  SetPred(n, Graph::kControlSlot,
          predicate_factory_.MakeAndPredicate(input_preds));

  return Status::OK();
}

Status DeadnessAnalysisImpl::HandleMerge(Node* n) {
  // Merge ignores deadness of its control inputs.  A merge that isn't the
  // target of a backedge is alive iff any of its data inputs are.  We treat
  // the liveness of a merge that is the target of a backedge symbolically.

  bool has_backedge = std::any_of(
      n->in_edges().begin(), n->in_edges().end(), [](const Edge* e) {
        return !e->IsControlEdge() && e->src()->IsNextIteration();
      });

  Predicate* input_data_pred =
      has_backedge ? predicate_factory_.MakeSymbolPredicate(
                         TensorId(n->name(), 0), /*must_be_true=*/false)
                   : predicate_factory_.MakeOrPredicate(
                         GetIncomingPreds(n, EdgeKind::kDataOnly));

  SetPred(n, {0, 1, Graph::kControlSlot}, input_data_pred);
  return Status::OK();
}

Status DeadnessAnalysisImpl::HandleRecv(Node* n) {
  // In addition to being alive or dead based on the inputs, a _Recv can also
  // acquire a dead signal from a _Send.
  std::vector<Predicate*> input_preds =
      GetIncomingPreds(n, EdgeKind::kDataAndControl);
  input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
      TensorId(n->name(), 0), /*must_be_true=*/false));
  SetPred(n, {0, Graph::kControlSlot},
          predicate_factory_.MakeAndPredicate(input_preds));
  return Status::OK();
}

Status DeadnessAnalysisImpl::HandleGeneric(Node* n) {
  // Generally nodes are alive iff all their inputs are alive.
  Predicate* pred = predicate_factory_.MakeAndPredicate(
      GetIncomingPreds(n, EdgeKind::kDataAndControl));
  for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
    SetPred(n, output_idx, pred);
  }
  SetPred(n, Graph::kControlSlot, pred);
  return Status::OK();
}

Status DeadnessAnalysisImpl::Populate() {
  std::vector<Node*> rpo;
  GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/{},
                      /*edge_filter=*/[](const Edge& edge) {
                        return !edge.src()->IsNextIteration();
                      });

  // This is an abstract interpretation over the deadness propagation semantics
  // of the graph executor.
  for (Node* n : rpo) {
    if (n->IsSwitch()) {
      TF_RETURN_IF_ERROR(HandleSwitch(n));
    } else if (n->IsMerge()) {
      TF_RETURN_IF_ERROR(HandleMerge(n));
    } else if (n->IsControlTrigger()) {
      SetPred(n, Graph::kControlSlot, predicate_factory_.MakeTrue());
    } else if (n->IsRecv() || n->IsHostRecv()) {
      TF_RETURN_IF_ERROR(HandleRecv(n));
    } else {
      TF_RETURN_IF_ERROR(HandleGeneric(n));
    }
  }

  return Status::OK();
}

bool DeadnessAnalysisImpl::HasInputsWithMismatchingDeadness(const Node& node) {
  CHECK(!node.IsMerge());

  if (vlog_) {
    VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name() << ")";
  }

  Predicate* pred = nullptr;
  for (const Edge* edge : node.in_edges()) {
    auto it = predicate_map_.find(InputEdgeToTensorId(edge));
    CHECK(it != predicate_map_.end());
    if (vlog_) {
      VLOG(2) << "  " << InputEdgeToTensorId(edge).ToString() << ": "
              << it->second->ToString();
    }

    // Today we just compare the predicates for equality (with some
    // canonicalization/simplification happening before) but we could be more
    // sophisticated here if need be.  Comparing pointers is sufficient because
    // we intern Predicate instances by their content.
    if (pred != nullptr && pred != it->second) {
      if (vlog_) {
        VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
                << ") -> true";
      }
      return true;
    }
    pred = it->second;
  }

  if (vlog_) {
    VLOG(2) << "HasInputsWithMismatchingDeadness(" << node.name()
            << ") -> false";
  }

  return false;
}

void DeadnessAnalysisImpl::Print() const {
  std::vector<TensorId> tensor_ids;
  for (const auto& kv_pair : predicate_map_) {
    tensor_ids.push_back(kv_pair.first);
  }

  std::sort(tensor_ids.begin(), tensor_ids.end());

  for (TensorId tensor_id : tensor_ids) {
    auto it = predicate_map_.find(tensor_id);
    CHECK(it != predicate_map_.end()) << tensor_id.ToString();
    VLOG(2) << tensor_id.ToString() << " -> " << it->second->ToString();
  }
}

}  // namespace

DeadnessAnalysis::~DeadnessAnalysis() {}

/*static*/ Status DeadnessAnalysis::Run(
    const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
  std::unique_ptr<DeadnessAnalysisImpl> analysis(
      new DeadnessAnalysisImpl(&graph));
  TF_RETURN_IF_ERROR(analysis->Populate());

  if (VLOG_IS_ON(2)) {
    analysis->Print();
  }

  *result = std::move(analysis);
  return Status::OK();
}

}  // namespace tensorflow
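[Editor's note] The interning comment in PredicateFactory above describes a classic hash-consing scheme: one canonical object per distinct formula, so structural equality reduces to pointer equality. As a rough, library-independent sketch of the same pattern (every name here is invented for illustration and is not part of the commit):

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Invented types for illustration; PredicateFactory plays this role above.
struct Expr {
  std::string kind;                   // e.g. "and", "or", "sym".
  std::vector<const Expr*> operands;  // Already-interned children.
};

class ExprFactory {
 public:
  // Returns the unique canonical Expr for (kind, operands); structurally
  // equal formulas always come back as the same pointer.
  const Expr* Make(const std::string& kind,
                   const std::vector<const Expr*>& operands) {
    auto key = std::make_pair(kind, operands);
    auto it = interned_.find(key);
    if (it != interned_.end()) return it->second.get();
    auto owned = std::make_unique<Expr>(Expr{kind, operands});
    const Expr* raw = owned.get();
    // The map doubles as owning storage, like the interned_*_instances_
    // fields above.
    interned_.emplace(std::move(key), std::move(owned));
    return raw;
  }

 private:
  std::map<std::pair<std::string, std::vector<const Expr*>>,
           std::unique_ptr<Expr>>
      interned_;
};

// Usage: structural equality is now a single pointer comparison.
//   ExprFactory f;
//   const Expr* a = f.Make("sym", {});
//   assert(f.Make("and", {a}) == f.Make("and", {a}));

The real factory additionally canonicalizes operands before interning (deduplication, flattening of nested And/Or, A&~A and A|~A folding, sorting by hash), which is what makes the bare pointer comparison in HasInputsWithMismatchingDeadness meaningful.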
68  tensorflow/compiler/jit/deadness_analysis.h  Normal file
@ -0,0 +1,68 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_

#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// This analyzes a TensorFlow graph to identify nodes which may have partially
// dead inputs (i.e. these nodes may have some dead inputs and some alive
// inputs).
//
// For example, the ADD node in the following graph
//
//      V0  PRED0    V1  PRED1
//       |    |       |    |
//       v    v       v    v
//       SWITCH       SWITCH
//          |            |
//          +---+    +---+
//              |    |
//              v    v
//               ADD
//
// can have its inputs independently dead or alive based on the runtime values
// of PRED0 and PRED1.
//
// It is tempting to call this a liveness analysis but I avoided that because
// "liveness" already has other connotations.
class DeadnessAnalysis {
 public:
  // Returns true if `node` may have some live inputs and some dead inputs.
  //
  // This is a conservatively correct routine -- if it returns false then
  // `node` is guaranteed to not have inputs with mismatching liveness, but not
  // the converse.
  //
  // REQUIRES: node is not a Merge operation.
  virtual bool HasInputsWithMismatchingDeadness(const Node& node) = 0;

  // Prints out the internal state of this instance.  For debugging purposes
  // only.
  virtual void Print() const = 0;
  virtual ~DeadnessAnalysis();

  // Runs the deadness analysis over `graph` and returns an error or a
  // populated instance of DeadnessAnalysis in `result`.
  static Status Run(const Graph& graph,
                    std::unique_ptr<DeadnessAnalysis>* result);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_
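[Editor's note] For orientation, here is a minimal usage sketch of the new analysis, modeled on the AnalyzeDeadness helper in deadness_analysis_test.cc below; the graph mirrors the two-output Switch example from the header comment, and the function name is invented for illustration:

#include <memory>

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/core/graph/algorithm.h"

// Hypothetical driver; see deadness_analysis_test.cc for the tested version.
void SketchDeadnessQuery() {
  using namespace tensorflow;  // Brevity only.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
  Output pred = ops::Placeholder(root.WithOpName("pred"), DT_BOOL);
  ops::Switch sw(root.WithOpName("sw"), value, pred);
  // Exactly one of the two Switch outputs is alive at runtime, so the Add has
  // inputs with mismatching deadness.
  Output add =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  // The tests normalize the graph before running the analysis.
  FixupSourceAndSinkEdges(root.graph());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_CHECK_OK(DeadnessAnalysis::Run(*root.graph(), &analysis));
  CHECK(analysis->HasInputsWithMismatchingDeadness(*add.node()));
}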
443  tensorflow/compiler/jit/deadness_analysis_test.cc  Normal file
@ -0,0 +1,443 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/deadness_analysis.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

Status AnalyzeDeadness(Graph* graph,
                       std::unique_ptr<DeadnessAnalysis>* result) {
  FixupSourceAndSinkEdges(graph);
  return DeadnessAnalysis::Run(*graph, result);
}

ops::Switch CreateSwitch(const Scope& root, const string& prefix) {
  Output value = ops::Placeholder(root.WithOpName(prefix + "/value"), DT_FLOAT);
  Output predicate =
      ops::Placeholder(root.WithOpName(prefix + "/pred"), DT_BOOL);
  return ops::Switch(root.WithOpName(prefix + "/switch"), value, predicate);
}

Output CreateInductionVariable(const Scope& root, const string& prefix,
                               const string& frame_name, int32 init) {
  Output initial_value = ops::Const(root.WithOpName(prefix + "/init"), init);
  Output enter_initial_value = ops::internal::Enter(
      root.WithOpName(prefix + "/enter"), initial_value, frame_name);

  ops::Merge iv(root.WithOpName(prefix + "/iv"), {enter_initial_value});
  Output increment_by = ops::Const(root.WithOpName(prefix + "/incr"), 1);
  Output final_value = ops::Const(root.WithOpName(prefix + "/final"), 10);
  Output loop_cond_expr =
      ops::Less(root.WithOpName(prefix + "/less"), iv.output, final_value);
  Output loop_cond =
      ops::LoopCond(root.WithOpName(prefix + "/cond"), loop_cond_expr);
  ops::Switch latch(root.WithOpName(prefix + "/latch"), iv.output, loop_cond);
  ops::internal::Exit exit(root.WithOpName(prefix + "/exit"), iv.output);
  Output iv_next =
      ops::Add(root.WithOpName(prefix + "/ivnext"), iv.output, increment_by);
  Output next_iteration =
      ops::NextIteration(root.WithOpName(prefix + "next_iteration"), iv_next);

  root.graph()->AddEdge(next_iteration.node(), 0, iv.output.node(), 1);
  root.graph()->AddControlEdge(iv.output.node(), increment_by.node());
  root.graph()->AddControlEdge(iv.output.node(), final_value.node());

  return iv.output;
}

TEST(DeadnessAnalysisTest, BasicPositive) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw = CreateSwitch(root, "0");
  Output add =
      ops::Add(root.WithOpName("add"), sw.output_true, sw.output_false);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, BasicNegative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output a = ops::Placeholder(root.WithOpName("a"), DT_FLOAT);
  Output b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), a, b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, AndIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 =
      ops::Add(root.WithOpName("a1"), sw_1.output_false, sw_0.output_false);

  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_0.output_false, sw_1.output_true);
  Output b1 =
      ops::Add(root.WithOpName("b1"), sw_1.output_true, sw_0.output_false);

  Output live0 = ops::Add(root.WithOpName("live0"), a0, a1);
  Output live1 = ops::Add(root.WithOpName("live1"), b0, b1);

  Output halfdead0 = ops::Add(root.WithOpName("halfdead0"), a0, b0);
  Output halfdead1 = ops::Add(root.WithOpName("halfdead1"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}

TEST(DeadnessAnalysisTest, AndIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  Output a0 =
      ops::Add(root.WithOpName("a0"), sw_0.output_false, sw_1.output_false);
  Output a1 = ops::Add(root.WithOpName("a1"), a0, sw_2.output_false);

  Output b0 =
      ops::Add(root.WithOpName("b0"), sw_1.output_false, sw_2.output_false);
  Output b1 = ops::Add(root.WithOpName("b1"), sw_0.output_false, b0);

  Output add = ops::Add(root.WithOpName("add"), a1, b1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, OrIsCommutative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_1.output_false, sw_0.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_0.output_false, sw_1.output_true});
  ops::Merge m3(root.WithOpName("m3"), {sw_1.output_true, sw_0.output_false});

  Output live0 = ops::Add(root.WithOpName("live0"), m0.output, m1.output);
  Output live1 = ops::Add(root.WithOpName("live1"), m2.output, m3.output);

  Output halfdead0 =
      ops::Add(root.WithOpName("halfdead0"), m0.output, m2.output);
  Output halfdead1 =
      ops::Add(root.WithOpName("halfdead1"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live0.node()));
  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*live1.node()));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead0.node()));
  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*halfdead1.node()));
}

TEST(DeadnessAnalysisTest, OrIsAssociative) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {m0.output, sw_2.output_false});
  ops::Merge m2(root.WithOpName("m2"), {sw_1.output_false, sw_2.output_false});
  ops::Merge m3(root.WithOpName("m3"), {sw_0.output_false, m2.output});

  Output add = ops::Add(root.WithOpName("add"), m1.output, m3.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, AndOfOr) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  ops::Merge m1(root.WithOpName("m1"), {sw_2.output_false, sw_3.output_false});

  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, m1.output);
  Output add1 = ops::Add(root.WithOpName("add1"), m0.output, m1.output);

  Output add2 = ops::Add(root.WithOpName("add2"), add0, add1);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}

TEST(DeadnessAnalysisTest, OrOfAnd) {
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");
  ops::Switch sw_3 = CreateSwitch(root, "3");

  Output add0 =
      ops::Add(root.WithOpName("add0"), sw_0.output_false, sw_1.output_false);
  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_2.output_false, sw_3.output_false);

  ops::Merge m0(root.WithOpName("m0"), {add0, add1});
  ops::Merge m1(root.WithOpName("m1"), {add0, add1});

  Output add2 = ops::Add(root.WithOpName("add2"), m0.output, m1.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}

TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
  // This demonstrates one of the weaknesses in the current approach -- since
  // we only do some basic simplifications we can't see that "(A|B)&C" ==
  // "(A&C)|(B&C)".
  Scope root = Scope::NewRootScope().ExitOnError();

  ops::Switch sw_0 = CreateSwitch(root, "0");
  ops::Switch sw_1 = CreateSwitch(root, "1");
  ops::Switch sw_2 = CreateSwitch(root, "2");

  ops::Merge m0(root.WithOpName("m0"), {sw_0.output_false, sw_1.output_false});
  Output add0 = ops::Add(root.WithOpName("add0"), m0.output, sw_2.output_false);

  Output add1 =
      ops::Add(root.WithOpName("add1"), sw_0.output_false, sw_2.output_false);
  Output add2 =
      ops::Add(root.WithOpName("add2"), sw_1.output_false, sw_2.output_false);
  ops::Merge m1(root.WithOpName("m1"), {add1, add2});

  Output add3 = ops::Add(root.WithOpName("add3"), add0, m1.output);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add3.node()));
}

TEST(DeadnessAnalysisTest, Ternary) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output predicate = ops::Placeholder(root.WithOpName("predicate"), DT_BOOL);
  Output true_value = ops::Placeholder(root.WithOpName("true_value"), DT_FLOAT);
  Output false_value =
      ops::Placeholder(root.WithOpName("false_value"), DT_FLOAT);

  ops::Switch predicated_true(root.WithOpName("predicated_true"), true_value,
                              predicate);

  ops::Switch predicated_false(root.WithOpName("predicated_false"), false_value,
                               predicate);
  ops::Merge merge(root.WithOpName("ternary"), {predicated_true.output_true,
                                                predicated_false.output_false});
  Output addend = ops::Placeholder(root.WithOpName("addend"), DT_FLOAT);
  Output add = ops::Add(root.WithOpName("add"), merge.output, addend);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, Recv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_FLOAT, "tensor_a",
                             "sender", 0, "receiver");
  Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_FLOAT, "tensor_b",
                             "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, HostRecv) {
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv_a = ops::_HostRecv(root.WithOpName("recv_a"), DT_FLOAT,
                                 "tensor_a", "sender", 0, "receiver");
  Output recv_b = ops::_HostRecv(root.WithOpName("recv_b"), DT_FLOAT,
                                 "tensor_b", "sender", 0, "receiver");
  Output add = ops::Add(root.WithOpName("add"), recv_a, recv_b);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, Loop) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output iv0 = CreateInductionVariable(root, "iv0", "fr0", 0);
  Output iv1 = CreateInductionVariable(root, "iv1", "fr0", 0);
  Output iv2 = CreateInductionVariable(root, "iv2", "fr0", 1);
  Output add0 = ops::Add(root.WithOpName("add0"), iv0, iv1);
  Output add1 = ops::Add(root.WithOpName("add1"), iv1, iv2);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  // NB!  iv0 and iv1 are equivalent and a smarter deadness analysis would have
  // noticed that.  Today we are pessimistic here because we assign an
  // uninterpreted symbol to merges with backedges.

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add0.node()));
  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add1.node()));
}

TEST(DeadnessAnalysisTest, ControlInputs) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output const0 = ops::Const(root.WithOpName("const0"), 1);
  Output const1 = ops::Const(root.WithOpName("const1"), 2);

  Output add = ops::Add(root.WithOpName("add"), const0, const1);

  root.graph()->AddControlEdge(id0.node(), const0.node());
  root.graph()->AddControlEdge(id1.node(), const1.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, ControlTrigger) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  ops::ControlTrigger ctrl_trigger0(root.WithOpName("ctrl_trigger0"));
  ops::ControlTrigger ctrl_trigger1(root.WithOpName("ctrl_trigger1"));

  Output const0 = ops::Const(root.WithOpName("const0"), 1);
  Output const1 = ops::Const(root.WithOpName("const1"), 2);

  Output add = ops::Add(root.WithOpName("add"), const0, const1);

  root.graph()->AddControlEdge(id0.node(), ctrl_trigger0.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger0.operation.node(), const0.node());

  root.graph()->AddControlEdge(id1.node(), ctrl_trigger1.operation.node());
  root.graph()->AddControlEdge(ctrl_trigger1.operation.node(), const1.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, ControlInputsToMerge) {
  Scope root = Scope::NewRootScope().ExitOnError();
  ops::Switch sw = CreateSwitch(root, "0");

  Output id0 = ops::Identity(root.WithOpName("id0"), sw.output_false);
  Output id1 = ops::Identity(root.WithOpName("id1"), sw.output_true);

  Output constant = ops::Const(root.WithOpName("constant"), 5);
  ops::Merge m0(root.WithOpName("m0"), {constant});
  ops::Merge m1(root.WithOpName("m1"), {constant});
  Output add = ops::Add(root.WithOpName("add"), m0.output, m1.output);

  root.graph()->AddControlEdge(id0.node(), m0.output.node());
  root.graph()->AddControlEdge(id1.node(), m1.output.node());

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add.node()));
}

TEST(DeadnessAnalysisTest, RecvVsSwitch) {
  // Demonstrates why we need the must_be_true bit on SymbolPredicate.
  Scope root = Scope::NewRootScope().ExitOnError();

  Output recv = ops::_Recv(root.WithOpName("recv"), DT_BOOL, "tensor", "sender",
                           0, "receiver");
  Output value = ops::Placeholder(root.WithOpName("value"), DT_BOOL);
  ops::Switch sw(root.WithOpName("switch"), value, recv);
  Output logical_and =
      ops::LogicalAnd(root.WithOpName("and"), recv, sw.output_true);

  std::unique_ptr<DeadnessAnalysis> result;
  TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

  EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*logical_and.node()));
}

}  // namespace
}  // namespace tensorflow
@ -60,9 +60,9 @@ const char* const kXlaHostTransferSequencerAttr =

namespace {

bool AreAllParentsConst(const Node& n,
                        const gtl::FlatSet<const Node*>& runtime_const_nodes) {
  if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
bool AreAllParentsGuaranteedConst(
    const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
  if (n.type_string() == "GuaranteeConst") {
    // If the current node is itself a cast-to-const, no need
    // to look at the incoming edges.
    return true;
@ -93,7 +93,8 @@ void MarkGuaranteedConstants(
  ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
                 /*leave=*/[&guaranteed_const_nodes](const Node* n) {
                   // TODO(vinuraja): Doesn't work in the presence of loops.
                   if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
                   if (AreAllParentsGuaranteedConst(*n,
                                                    guaranteed_const_nodes)) {
                     guaranteed_const_nodes.insert(n);
                   }
                 });
@ -137,7 +138,7 @@ class Encapsulator {

  // Find subgraphs marked with 'group_attribute', and build a new
  // subgraph, one for each value of 'group_attribute'.
  Status SplitIntoSubgraphs();
  Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);

  // Build a FunctionDef for each subgraph, and add it to 'library'. The values
  // of the 'group_attribute' annotations become the function names.
@ -1136,7 +1137,10 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo(
        GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef));
    host_compute->AddAttr("shape_inference_graph", inference_graph_name);
    host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
    TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
    // TODO(sibyl-Aix6ihai): Understand why there are multiple calls to Encapsulator.
    if (library->Find(inference_graph_name) == nullptr) {
      TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
    }
  }
  return Status::OK();
}
@ -1474,7 +1478,7 @@ Status Encapsulator::CopySubgraphEdges(
  return Status::OK();
}

Status Encapsulator::SplitIntoSubgraphs() {
Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
  Status s;

  // Map from input graph nodes to subgraph nodes.
@ -1509,6 +1513,15 @@ Status Encapsulator::SplitIntoSubgraphs() {
    TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy));
  }

  if (VLOG_IS_ON(1)) {
    // Dump subgraphs.
    for (auto& entry : subgraphs_) {
      dump_graph::DumpGraphToFile(
          strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
          *entry.second.GetGraph(), library);
    }
  }

  return s;
}
@ -1932,6 +1945,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
// continue.
|
||||
TensorShapeProto proto;
|
||||
context->ShapeHandleToProto(shape, &proto);
|
||||
VLOG(2) << "Node " << src_node->name()
|
||||
<< " has known shape: " << proto.DebugString();
|
||||
if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
|
||||
dummy_node_images[src_node] =
|
||||
AddDummyShapedNode(src_node, src_port, control_flow_info,
|
||||
@ -1949,6 +1964,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
|
||||
if (VLOG_IS_ON(2)) {
|
||||
TensorShapeProto proto;
|
||||
context->ShapeHandleToProto(shape, &proto);
|
||||
VLOG(2) << "Node " << src_node->name()
|
||||
<< " has unknown shape: " << proto.DebugString();
|
||||
}
|
||||
stack.push_back({src_node, false});
|
||||
}
|
||||
@ -2191,6 +2208,23 @@ Status Encapsulator::FindClusterDependencies() {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (VLOG_IS_ON(2)) {
|
||||
// Print debug information.
|
||||
VLOG(2) << "node_ancestors_map:";
|
||||
for (const auto& node_iter : node_ancestors_map) {
|
||||
VLOG(2) << "\t" << node_iter.first->name() << ": subgraph = '"
|
||||
<< node_iter.second.subgraph
|
||||
<< "', outside_compilation_cluster = '"
|
||||
<< node_iter.second.outside_compilation_cluster
|
||||
<< "', ancestor_clusters: "
|
||||
<< (node_iter.second.ancestor_clusters.empty() ? "(empty)" : "");
|
||||
for (const auto& cluster_iter : node_iter.second.ancestor_clusters) {
|
||||
VLOG(2) << "\t\tsubgraph = '" << cluster_iter.subgraph
|
||||
<< "', outside_compilation_cluster = '"
|
||||
<< cluster_iter.outside_compilation_cluster << "'";
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -2398,7 +2432,7 @@ Status EncapsulateSubgraphsInFunctions(
|
||||
std::move(outside_compilation_attribute),
|
||||
&graph_in);
|
||||
TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies());
|
||||
TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
|
||||
TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
|
||||
|
||||
TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
|
||||
rewrite_subgraph_fn, reuse_existing_functions, library));
|
||||
@ -2447,7 +2481,7 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
VLOG(1) << "EncapsulateSubgraphsPass::Run";
|
||||
if (VLOG_IS_ON(1)) {
|
||||
dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
|
||||
dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
|
||||
options.flib_def);
|
||||
}
|
||||
|
||||
@ -2530,7 +2564,7 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
"EncapsulateSubgraphsPass failed");
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
|
||||
dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
|
||||
options.flib_def);
|
||||
}
|
||||
|
||||
|
@ -742,10 +742,13 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
|
||||
"/job:localhost/replica:0/task:0/cpu:0");
|
||||
auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
|
||||
auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f);
|
||||
auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
|
||||
auto const_guarantee_x2 =
|
||||
ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
|
||||
auto const_guarantee_x1 =
|
||||
ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
|
||||
auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2);
|
||||
auto add1 =
|
||||
ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_guarantee_x2);
|
||||
add1.node()->AddAttr("_encapsulate", "encapsulate1");
|
||||
|
||||
Graph graph_before(OpRegistry::Global());
|
||||
|
@ -51,7 +51,11 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
if (device_type_ == DeviceType(DEVICE_CPU)) {
|
||||
platform_id_ = se::host::kHostPlatformId;
|
||||
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
|
||||
platform_id_ = se::cuda::kCudaPlatformId;
|
||||
platform_id_ = ctx->device()
|
||||
->tensorflow_gpu_device_info()
|
||||
->stream->parent()
|
||||
->platform()
|
||||
->id();
|
||||
} else {
|
||||
platform_id_ = nullptr;
|
||||
}
|
||||
@ -115,6 +119,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
const XlaDevice::Metadata* metadata = nullptr;
|
||||
Status s = XlaDevice::GetMetadata(ctx, &metadata);
|
||||
bool allocate_xla_tensors = s.ok();
|
||||
bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
|
||||
|
||||
// Get the platform_id_ for XLA_* devices.
|
||||
if (platform_id_ == nullptr) {
|
||||
@ -180,8 +185,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
XlaComputationLaunchContext launch_context(client, xla_allocator,
|
||||
allocate_xla_tensors);
|
||||
XlaComputationLaunchContext launch_context(
|
||||
client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
|
||||
launch_context.PopulateInputs(ctx, kernel, variables);
|
||||
|
||||
// Execute the computation.
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
|
||||
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/memory_types.h"
|
||||
@ -462,18 +464,19 @@ Status MarkForCompilationPass::Run(
|
||||
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
|
||||
const FunctionLibraryDefinition* fld = options.flib_def;
|
||||
|
||||
auto is_compilable = [global_jit_level, cpu_global_jit, fusion_only, fld](
|
||||
const Node* node, const DeviceType& device_type) {
|
||||
std::unique_ptr<DeadnessAnalysis> deadness;
|
||||
{
|
||||
XLA_SCOPED_LOGGING_TIMER_LEVEL("DeadnessAnalysis", 1);
|
||||
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(**options.graph, &deadness));
|
||||
}
|
||||
|
||||
auto is_compilable = [&](const Node* node, const DeviceType& device_type) {
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
|
||||
®istration)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Don't compile control trigger nodes. We won't preserve their deadness
|
||||
// semantics correctly, so it's safest not to compile them.
|
||||
if (node->IsControlTrigger()) return false;
|
||||
|
||||
// If this device requires a JIT, we must say yes.
|
||||
if (registration->requires_compilation) return true;
|
||||
|
||||
@ -485,6 +488,14 @@ Status MarkForCompilationPass::Run(
|
||||
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
|
||||
if (status.ok()) return compile;
|
||||
|
||||
// If inputs to `node` can have conflicting deadness (i.e. some are alive
|
||||
// and some are dead) then don't compile it. XLA cannot represent the
|
||||
// deadness semantics of these nodes correctly and auto-clustering these
|
||||
// nodes can cause deadness to propagate to nodes that should be live.
|
||||
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for fusable ops only if requested.
|
||||
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
|
||||
return false;
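
For orientation, the deadness gate added in this hunk can be read in isolation roughly as follows. This is a minimal sketch of the API usage shown above, not code from the commit; `graph` is assumed to be a `Graph&` already in scope.

// Sketch: skip auto-clustering for any node whose inputs may disagree on
// liveness. DeadnessAnalysis::Run and HasInputsWithMismatchingDeadness are
// the interfaces this commit introduces into the pass.
std::unique_ptr<DeadnessAnalysis> deadness;
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
for (Node* node : graph.nodes()) {
  if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
    continue;  // Unsafe for XLA: deadness could leak into live nodes.
  }
  // ... node remains a candidate for compilation ...
}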

@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -680,5 +681,37 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
EXPECT_EQ(clusters, expected_clusters);
}

TEST(XlaCompilationTest, ClusterControlTrigger) {
Scope root = Scope::NewRootScope().ExitOnError();

Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a",
"sender", 0, "receiver");
Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b",
"sender", 0, "receiver");
Output const_a = ops::Const(root.WithOpName("const_a"), 42);

ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a"));
ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b"));
root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node());
root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node());
root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node());

std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));

std::unordered_map<string, string> clusters = GetClusters(*graph);

ASSERT_FALSE(clusters.empty());
string cluster_name = clusters.begin()->second;

// ctrl_trigger_a has inputs with mismatching deadness so it won't be
// clustered. ctrl_trigger_b is okay to cluster.
std::unordered_map<string, string> expected_clusters(
{{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}});
EXPECT_EQ(clusters, expected_clusters);
}

} // namespace
} // namespace tensorflow

@@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,

// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(), true);
client, client->backend().memory_allocator(),
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());

launch_context.PopulateInputs(ctx, result, variables);

@@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
DEVICE_CPU_XLA_JIT, options, name_prefix,
registration,
/*transfer_as_literal=*/false,
/*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());

@@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));

device->reset(new XlaDevice(
options, attrs, device_ordinal, DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
device->reset(
new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal,
use_multiple_streams, shape_representation_fn,
padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}

XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
PaddedShapeFn padded_shape_fn)
PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)),
padded_shape_fn_(std::move(padded_shape_fn)) {}
padded_shape_fn_(std::move(padded_shape_fn)),
use_multiple_streams_(use_multiple_streams) {}

int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

@@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
shape_representation_fn, padded_shape_fn),
shape_representation_fn, padded_shape_fn,
use_multiple_streams),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
use_multiple_streams_(use_multiple_streams),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
@@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
return stream_.get();
}

xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
if (!use_multiple_streams_) {
return GetStream();
}
if (!device_to_host_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(device_to_host_stream_,
backend->BorrowStream(device_ordinal_));
}
return device_to_host_stream_.get();
}

xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
if (!use_multiple_streams_) {
return GetStream();
}
if (!host_to_device_stream_) {
xla::Backend* backend = client()->mutable_backend();
TF_ASSIGN_OR_RETURN(host_to_device_stream_,
backend->BorrowStream(device_ordinal_));
}
return host_to_device_stream_.get();
}
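
Both getters follow the same borrow-once pattern; a condensed restatement (mine, not in the commit; names other than `BorrowStream` and `StreamPtr` are assumptions):

// Sketch of the lazy-borrow pattern used by both getters: fall back to the
// single compute stream unless multi-stream mode is on, then borrow a
// dedicated stream from the XLA backend exactly once and cache it.
xla::StatusOr<se::Stream*> LazyTransferStream(bool use_multiple_streams,
                                              xla::Backend* backend,
                                              int device_ordinal,
                                              xla::Backend::StreamPtr* cache,
                                              se::Stream* compute_stream) {
  if (!use_multiple_streams) return compute_stream;
  if (!*cache) {
    TF_ASSIGN_OR_RETURN(*cache, backend->BorrowStream(device_ordinal));
  }
  return cache->get();
}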

Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
gpu_device_info_->default_context = new XlaDeviceContext(
stream, client(), transfer_as_literal_, shape_representation_fn_);
gpu_device_info_->default_context =
new XlaDeviceContext(stream, stream, stream, client(),
transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}

@@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());

// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
shape_representation_fn_);
auto ctx = new XlaDeviceContext(
stream, host_to_device_stream, device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
XlaTransferManager manager(stream, client(), transfer_as_literal_,
shape_representation_fn_);
TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
GetDeviceToHostStream());
TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
GetHostToDeviceStream());
XlaTransferManager manager(stream, host_to_device_stream,
device_to_host_stream, client(),
transfer_as_literal_, shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;

@@ -57,7 +57,7 @@ class XlaDevice : public LocalDevice {
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
PaddedShapeFn padded_shape_fn);
PaddedShapeFn padded_shape_fn, bool use_multiple_streams);

// The index of the device on this host.
int device_ordinal() const;
@@ -70,12 +70,15 @@ class XlaDevice : public LocalDevice {
}
const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }

bool UseMultipleStreams() const { return use_multiple_streams_; }

private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
PaddedShapeFn padded_shape_fn_;
const bool use_multiple_streams_;

TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@@ -89,6 +92,8 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
// If 'use_multiple_streams' is true, we create separate streams for
// host-to-device and device-to-host communication.
// If padded_shape_fn is empty, a default implementation that returns
// the on-host shape is used.
static Status Create(
@@ -96,7 +101,7 @@ class XlaDevice : public LocalDevice {
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
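
A hypothetical call site under the new signature, mirroring the device factories updated later in this commit (the literal arguments here are illustrative, not from the commit):

// Create a single-stream XLA CPU device that transfers via ThenMemcpy.
std::unique_ptr<XlaDevice> device;
TF_RETURN_IF_ERROR(XlaDevice::Create(
    "Host", DEVICE_XLA_CPU, /*device_ordinal=*/0, DEVICE_CPU_XLA_JIT, options,
    name_prefix, registration,
    /*transfer_as_literal=*/false,
    /*use_multiple_streams=*/false,
    /*shape_representation_fn=*/{},
    /*padded_shape_fn=*/{}, &device));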

@@ -106,6 +111,7 @@ class XlaDevice : public LocalDevice {
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
@@ -126,6 +132,8 @@ class XlaDevice : public LocalDevice {
xla::LocalClient* client() const;
const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
xla::StatusOr<se::Stream*> GetHostToDeviceStream();
xla::StatusOr<se::Stream*> GetDeviceToHostStream();

// If not already set, create and set GpuDeviceInfo.
// Not thread-safe
@@ -146,6 +154,16 @@ class XlaDevice : public LocalDevice {
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
xla::Backend::StreamPtr stream_;
// If true, only stream_ is valid and all computation and transfers use
// stream_. If false, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
xla::Backend::StreamPtr host_to_device_stream_;
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
xla::Backend::StreamPtr device_to_host_stream_;
// Must we use XLA's transfer manager for correct host<->device transfers? If
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;

@@ -48,17 +48,24 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }

XlaTransferManager::XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: stream_(stream),
: stream_(compute_stream),
host_to_device_stream_(host_to_device_stream),
device_to_host_stream_(device_to_host_stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
if (!shape_representation_fn_) {
shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) {
return shape;
};
shape_representation_fn_ =
[](const TensorShape& shape,
DataType dtype) -> xla::StatusOr<TensorShape> { return shape; };
}
}

@@ -74,15 +81,26 @@ Status XlaTransferManager::TransferLiteralToDevice(
auto literal = std::make_shared<xla::BorrowingLiteral>(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);

const xla::ShapedBuffer& shaped_buffer =
XlaTensor::FromTensor(device_tensor)->shaped_buffer();
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
if (UseMultipleStreams()) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_);
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
stream_, *literal, shaped_buffer));
host_to_device_stream_, *literal, shaped_buffer));
if (UseMultipleStreams()) {
se::Event event(stream_->parent());
TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
host_to_device_stream_->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
host_to_device_stream_->ThenDoHostCallback([ref, literal]() { ref.Unref(); });
return Status::OK();
}
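
The host-to-device path above is the producer half of the definition-event protocol. Reduced to its core (my paraphrase; `tensor` and the stream members are assumed to be in scope):

// Producer side: order the transfer stream after the compute stream so
// allocations are visible, enqueue the transfer on the dedicated h2d stream,
// then record an event there and attach it to the tensor so consumers on
// other streams know when the contents become valid.
host_to_device_stream_->ThenWaitFor(stream_);
// ... enqueue the transfer on host_to_device_stream_ ...
se::Event event(stream_->parent());
TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
host_to_device_stream_->ThenRecordEvent(&event);
tensor->SetDefinedOn(host_to_device_stream_, std::move(event));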

@@ -94,7 +112,7 @@ void XlaTransferManager::TransferLiteralFromDevice(

TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
stream_, shaped_buffer,
device_to_host_stream_, shaped_buffer,
[=, &shaped_buffer](
xla::StatusOr<std::unique_ptr<xla::Literal> > literal_or) {
ref.Unref();
@@ -120,62 +138,73 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
StatusCallback done) const {
if (cpu_tensor->NumElements() > 0) {
VLOG(2) << "CopyCPUTensorToDevice "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
<< " " << cpu_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();

void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
const int64 total_bytes = cpu_tensor->TotalBytes();

XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);

TensorShape shape = shape_representation_fn_(device_tensor->shape(),
device_tensor->dtype());
Status status;
if (!xla_tensor->has_shaped_buffer()) {
status = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
if (!status.ok()) {
return done(status);
}
}

if (transfer_as_literal_) {
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);

done(status);
if (cpu_tensor->NumElements() == 0) {
VLOG(2) << "CopyCPUTensorToDevice empty tensor";
done(Status::OK());
return;
}

VLOG(2) << "CopyCPUTensorToDevice empty tensor";
done(Status::OK());
VLOG(2) << "CopyCPUTensorToDevice "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(device_tensor->tensor_data().data())
<< " " << cpu_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();

void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
const int64 total_bytes = cpu_tensor->TotalBytes();

XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);

xla::StatusOr<TensorShape> shape_or_status =
shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
if (!shape_or_status.ok()) {
done(shape_or_status.status());
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
if (!xla_tensor->has_shaped_buffer()) {
Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
if (!s.ok()) {
done(s);
return;
}
}

Status status;
if (transfer_as_literal_) {
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
host_to_device_stream_->ThenDoHostCallback(
[done]() { done(Status::OK()); });
return;
}
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = host_to_device_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
host_to_device_stream_, block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);

done(status);
}
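
Note that shape_representation_fn_ now returns xla::StatusOr<TensorShape> rather than a bare TensorShape, so callers must unpack or propagate the status. A minimal example callback under the new signature (hypothetical, for illustration only):

// A shape_representation_fn that keeps the host shape unchanged but can
// surface an error instead of crashing on an unsupported dtype.
XlaCompiler::ShapeRepresentationFn fn =
    [](const TensorShape& shape,
       DataType dtype) -> xla::StatusOr<TensorShape> {
  if (dtype == DT_INVALID) {
    return errors::InvalidArgument("No valid dtype for shape representation");
  }
  return shape;
};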

void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
@@ -183,68 +212,102 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Device* device,
Tensor* cpu_tensor,
StatusCallback done) {
if (device_tensor->NumElements() > 0) {
VLOG(2) << "CopyDeviceTensorToCPU "
<< reinterpret_cast<const void*>(
device_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " " << device_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();

const int64 total_bytes = cpu_tensor->TotalBytes();
se::DeviceMemoryBase dev_src_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
void* dst_ptr = DMAHelper::base(cpu_tensor);

Status status;
if (transfer_as_literal_) {
TransferLiteralFromDevice(cpu_tensor, *device_tensor, done);
return;
} else {
stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
done(status);
}
if (device_tensor->NumElements() == 0) {
VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
done(Status::OK());
return;
}
VLOG(2) << "CopyDeviceTensorToCPU "
<< reinterpret_cast<const void*>(device_tensor->tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
<< " " << device_tensor->NumElements() << " "
<< cpu_tensor->shape().DebugString() << " "
<< device_tensor->shape().DebugString();

const int64 total_bytes = cpu_tensor->TotalBytes();
se::DeviceMemoryBase dev_src_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
void* dst_ptr = DMAHelper::base(cpu_tensor);
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);

if (se::Event* event =
xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
device_to_host_stream_->ThenWaitFor(event);
xla_tensor->SetDefinedOn(device_to_host_stream_);
}

Status status;
if (transfer_as_literal_) {
TransferLiteralFromDevice(cpu_tensor, *device_tensor, done);
return;
} else {
device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s", stream_,
block_status.error_message().c_str());
}
}

done(status);
}

void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
Tensor* dst_tensor,
const StatusCallback& done) {
VLOG(2) << "CopyDeviceTensorToDevice "
<< reinterpret_cast<const void*>(src_tensor.tensor_data().data())
<< " "
<< reinterpret_cast<const void*>(dst_tensor->tensor_data().data());
// Perform memory allocation now, and enqueue the device-to-device transfer.
Status status = [&]() -> Status {
if (src_tensor.NumElements() == 0) {
return Status::OK();
}
// TODO(jmolloy): We co-opt the device_to_host stream for device to device
// transfers; perhaps we should have a dedicated device to device stream? or
// one per device?
auto device_to_device_stream = stream_;
XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
CHECK(xla_src && xla_dst)
<< "Missing destination tensor for device-to-device copy";
if (!xla_dst->has_shaped_buffer()) {
TensorShape shape =
shape_representation_fn_(src_tensor.shape(), src_tensor.dtype());
TF_ASSIGN_OR_RETURN(
TensorShape shape,
shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()));
TF_RETURN_IF_ERROR(
xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
stream_->parent()->device_ordinal()));
if (stream_ != device_to_device_stream) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
device_to_device_stream->ThenWaitFor(stream_);
}
}

if (se::Event* event =
xla_src->GetDefinitionEvent(device_to_device_stream)) {
device_to_device_stream->ThenWaitFor(event);
xla_src->SetDefinedOn(device_to_device_stream);
}

auto from_iter = xla_src->shaped_buffer().buffers().begin();
auto to_iter = xla_dst->shaped_buffer().buffers().begin();
for (auto end_iter = xla_src->shaped_buffer().buffers().end();
from_iter != end_iter; ++from_iter, ++to_iter) {
stream_->ThenMemcpyD2D(&to_iter->second, from_iter->second,
to_iter->second.size());
device_to_device_stream->ThenMemcpyD2D(
&to_iter->second, from_iter->second, to_iter->second.size());
}

if (UseMultipleStreams()) {
se::Event event(stream_->parent());
CHECK(event.Init());
device_to_device_stream->ThenRecordEvent(&event);
xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
}
return Status::OK();
}();
@@ -256,9 +319,12 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}

XlaDeviceContext::XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: manager_(stream, client, transfer_as_literal,
: manager_(compute_stream, host_to_device_stream, device_to_host_stream,
client, transfer_as_literal,
std::move(shape_representation_fn)) {}

void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,

@@ -47,7 +47,9 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);

void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@@ -67,10 +69,17 @@ class XlaTransferManager {
void TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor,
const StatusCallback& done) const;
bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }

// Stream obtained from a Device, used to transfer tensors between
// CPU and device.
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
se::Stream* stream_;
// The stream to use for transferring data from host to device. Can be
// identical to stream_, but must not be nullptr.
se::Stream* host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
// identical to stream_, but must not be nullptr.
se::Stream* device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -86,7 +95,9 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
se::Stream* compute_stream, se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);

void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,

@@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
@@ -75,9 +77,7 @@ class XlaAssignVariableOp : public AsyncOpKernel {
ConstantOp); \
REGISTER_KERNEL_BUILDER( \
Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
REGISTER_KERNEL_BUILDER( \
Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \
IdentityNOp); \
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE), IdentityNOp); \
REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
PlaceholderOp); \
@@ -88,6 +88,9 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
REGISTER_KERNEL_BUILDER( \
Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \
DestroyResourceOp); \
REGISTER_KERNEL_BUILDER(Name("Shape") \
.Device(DEVICE) \
.HostMemory("output") \
@@ -145,7 +148,32 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
LoopCondOp);
LoopCondOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
.Device(DEVICE) \
.HostMemory("size") \
.HostMemory("handle"), \
QueueSizeOp); \
REGISTER_KERNEL_BUILDER( \
Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
QueueIsClosedOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);

// TODO(phawkins): currently we do not register the QueueEnqueueMany,
// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
// and write the tensors they access in order to concatenate them into a batch.
// We would need either to call out to an XLA computation to perform the
// concatenation, or we would need to refactor those kernels so the splitting
// or merging is done in a separate operator that can be compiled.

} // namespace tensorflow

@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
@@ -146,6 +147,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));

std::unique_ptr<DeadnessAnalysis> deadness;
TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));

// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
@@ -185,6 +189,14 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
continue;
}

// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
continue;
}

compilation_candidates.insert(node);
}

@@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
/*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device);
if (!status.ok()) {

@@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
DEVICE_INTERPRETER_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
/*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());

@@ -64,11 +64,13 @@ xla::StatusOr<xla::OwningDeviceMemory> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
AllocationAttributes attrs;
attrs.no_retry_on_failure = !retry_on_failure;
void* data =
wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
if (data == nullptr) {
return errors::ResourceExhausted("Out of memory while trying to allocate ",
size, " bytes.");
void* data = nullptr;
if (size != 0) {
data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs);
if (data == nullptr) {
return errors::ResourceExhausted(
"Out of memory while trying to allocate ", size, " bytes.");
}
}
return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size),
device_ordinal, this);
@@ -115,14 +117,22 @@ using internal::ExtractSubShapedBuffer;

XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors)
bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors) {}
allocate_xla_tensors_(allocate_xla_tensors),
use_multiple_streams_(use_multiple_streams) {
if (use_multiple_streams_) {
CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
"be allocating XLA tensors!";
}
}

void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
@@ -140,6 +150,16 @@ void XlaComputationLaunchContext::PopulateInputs(
t = &(ctx->input(arg_num));
}

if (use_multiple_streams_) {
CHECK(stream) << "Must have a stream available when using XLA tensors!";
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor);
if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
stream->ThenWaitFor(event);
xla_tensor->SetDefinedOn(stream);
}
}

const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
if (xla::ShapeUtil::IsTuple(on_device_shape)) {
@@ -248,6 +268,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
se::Event event(stream->parent());
CHECK(event.Init());
stream->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(stream, std::move(event));
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
// tensor.
@@ -302,6 +328,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
se::Event event(stream->parent());
CHECK(event.Init());
stream->ThenRecordEvent(&event);
xla_tensor->SetDefinedOn(stream, std::move(event));
}
*variable->tensor() = output_tensor;
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(

@@ -76,9 +76,15 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
// If 'use_multiple_streams' is true, tensors may be defined and used on
// multiple streams and so se::Events must be defined and waited for. If
// 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true
// because we track inter-stream dependencies through events inside XlaTensor
// objects.
XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors);
bool allocate_xla_tensors,
bool use_multiple_streams);

// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
@@ -99,6 +105,7 @@ class XlaComputationLaunchContext {
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};
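
A hypothetical construction under the new four-argument form, matching the call sites updated earlier in this commit; as the header comment notes, multi-stream mode requires XlaTensor allocation so definition events have somewhere to live:

// Build a launch context that tracks inter-stream dependencies via events.
XlaComputationLaunchContext launch_context(
    client, client->backend().memory_allocator(),
    /*allocate_xla_tensors=*/true,
    /*use_multiple_streams=*/true);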
@@ -115,7 +122,11 @@ class XlaTensorBuffer : public TensorBuffer {
data_ = const_cast<void*>(ptr);
}

~XlaTensorBuffer() override { allocator_->DeallocateRaw(data_); }
~XlaTensorBuffer() override {
if (data_) {
allocator_->DeallocateRaw(data_);
}
}

void* data() const override { return data_; }
size_t size() const override { return expected_size_; }
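
This destructor guard pairs with the XlaAllocator::Allocate change earlier in the commit, which now returns a null base pointer for zero-byte requests instead of failing. A condensed view of the invariant (my restatement; `wrapped`, `size`, and `attrs` are assumed to be in scope):

// Invariant after this commit: zero-byte allocations yield nullptr rather
// than an error, so owners must only free non-null bases.
void* base = (size == 0) ? nullptr
                         : wrapped->AllocateRaw(Allocator::kAllocatorAlignment,
                                                size, attrs);
// ... later, in the owning buffer's destructor ...
if (base) {
  wrapped->DeallocateRaw(base);
}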

@@ -73,6 +73,34 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
return Status::OK();
}

se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
mutex_lock lock(mu_);
if (!definition_event_.has_value()) {
return nullptr;
}

// The set of defined streams is expected to be very small indeed (usually
// 1-2), so a simple linear scan should be fast enough.
if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
stream) != streams_defined_on_.end()) {
// stream is in streams_defined_on_; it doesn't need to be waited on.
return nullptr;
}

return &*definition_event_;
}

void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
mutex_lock lock(mu_);
definition_event_ = std::move(event);
streams_defined_on_ = {stream};
}

void XlaTensor::SetDefinedOn(se::Stream* stream) {
mutex_lock lock(mu_);
streams_defined_on_.push_back(stream);
}
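
The consumer side of this protocol, as used by the copy paths elsewhere in the commit: before reading the tensor on a stream, wait on its definition event once, then record that the stream is now safe so repeated lookups return nullptr and skip the wait.

// Consumer side: synchronize `stream` with the tensor's producer exactly once.
if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
  stream->ThenWaitFor(event);
  xla_tensor->SetDefinedOn(stream);
}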

// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.

@@ -85,6 +85,24 @@ class XlaTensor {
host_tensor_.reset(new Tensor(tensor));
}

// If the tensor's content is not yet defined on 'stream', and there exists an
// se::Event declaring when the tensor's content is defined, return it.
// Otherwise, return nullptr. If this function returns nullptr then the
// tensor's content can be read on 'stream' without additional
// synchronization.
se::Event* GetDefinitionEvent(se::Stream* stream);

// Assert that the tensor's content is defined on 'stream' by the time 'event'
// triggers.
void SetDefinedOn(se::Stream* stream, se::Event event);

// Assert that the tensor's content is defined on 'stream'. This version does
// not provide an event, and must be called *after* SetDefinedOn(Stream,
// Event). This call can be read as an assertion that the definition event has
// been waited on by 'stream', so further calls to GetDefinitionEvent(stream)
// do not need to also wait on the event.
void SetDefinedOn(se::Stream* stream);

// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
static XlaTensor* FromOpaquePointer(void* ptr);
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.
@@ -95,6 +113,14 @@ class XlaTensor {
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.
gtl::optional<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
mutex mu_;
};

} // namespace tensorflow

@@ -70,6 +70,19 @@ py_test(
],
)

tf_xla_py_test(
name = "adadelta_test",
size = "medium",
srcs = ["adadelta_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "adagrad_test",
size = "small",
@@ -84,6 +97,19 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "adagrad_da_test",
size = "small",
srcs = ["adagrad_da_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "adam_test",
size = "small",
@@ -98,6 +124,48 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "adamax_test",
size = "small",
srcs = ["adamax_test.py"],
deps = [
":xla_test",
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "addsign_test",
size = "small",
srcs = ["addsign_test.py"],
deps = [
":xla_test",
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "powersign_test",
size = "small",
srcs = ["powersign_test.py"],
deps = [
":xla_test",
"//tensorflow/contrib/opt:opt_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "argminmax_test",
size = "small",
@@ -167,7 +235,7 @@ tf_xla_py_test(

tf_xla_py_test(
name = "cholesky_op_test",
size = "small",
size = "medium",
srcs = ["cholesky_op_test.py"],
tags = ["optonly"],
deps = [
@@ -350,7 +418,7 @@ tf_xla_py_test(

tf_xla_py_test(
name = "eager_test",
size = "small",
size = "large",
srcs = ["eager_test.py"],
disabled_backends = [
# TODO(b/78199195) Support XLA CPU devices in eager runtime
@@ -371,6 +439,20 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "fifo_queue_test",
size = "medium",
srcs = ["fifo_queue_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)

tf_xla_py_test(
name = "fft_test",
size = "medium",
@@ -556,6 +638,53 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "proximal_adagrad_test",
size = "medium",
srcs = ["proximal_adagrad_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "proximal_gradient_descent_test",
size = "medium",
srcs = ["proximal_gradient_descent_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:training",
],
)

tf_xla_py_test(
name = "qr_op_test",
size = "medium",
srcs = ["qr_op_test.py"],
disabled_backends = [
# Test is very slow on CPU.
"cpu",
"cpu_ondemand",
],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
"@absl_py//absl/testing:parameterized",
],
)

tf_xla_py_test(
name = "random_ops_test",
size = "small",
@@ -871,8 +1000,10 @@ tf_xla_py_test(

tf_xla_py_test(
name = "sort_ops_test",
size = "small",
size = "medium",
srcs = ["sort_ops_test.py"],
# Times out in fastbuild mode.
tags = ["optonly"],
deps = [
"//tensorflow/compiler/tests:xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
134
tensorflow/compiler/tests/adadelta_test.py
Normal file
@ -0,0 +1,134 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Adadelta Optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adadelta


class AdadeltaOptimizerTest(xla_test.XLATestCase):

  def testBasic(self):
    num_updates = 4  # number of ADADELTA steps to perform
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        for grad in [0.2, 0.1, 0.01]:
          for lr in [1.0, 0.5, 0.1]:
            var0_init = [1.0, 2.0]
            var1_init = [3.0, 4.0]
            var0 = resource_variable_ops.ResourceVariable(
                var0_init, dtype=dtype)
            var1 = resource_variable_ops.ResourceVariable(
                var1_init, dtype=dtype)

            grads = constant_op.constant([grad, grad], dtype=dtype)

            accum = 0.0
            accum_update = 0.0

            # ADADELTA gradient optimizer
            rho = 0.95
            epsilon = 1e-8
            adadelta_opt = adadelta.AdadeltaOptimizer(
                learning_rate=lr, rho=rho, epsilon=epsilon)
            adadelta_update = adadelta_opt.apply_gradients(
                zip([grads, grads], [var0, var1]))
            self.evaluate(variables.global_variables_initializer())
            opt_vars = adadelta_opt.variables()
            self.assertStartsWith(opt_vars[0].name, var0._shared_name)
            self.assertStartsWith(opt_vars[1].name, var0._shared_name)
            self.assertStartsWith(opt_vars[2].name, var1._shared_name)
            self.assertStartsWith(opt_vars[3].name, var1._shared_name)
            self.assertEqual(4, len(opt_vars))
            # Assign slots
            slot = [None] * 2
            slot_update = [None] * 2
            self.assertEqual(["accum", "accum_update"],
                             adadelta_opt.get_slot_names())
            slot[0] = adadelta_opt.get_slot(var0, "accum")
            self.assertEquals(slot[0].get_shape(), var0.get_shape())
            self.assertFalse(slot[0] in variables.trainable_variables())

            slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
            self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
            self.assertFalse(slot_update[0] in variables.trainable_variables())

            slot[1] = adadelta_opt.get_slot(var1, "accum")
            self.assertEquals(slot[1].get_shape(), var1.get_shape())
            self.assertFalse(slot[1] in variables.trainable_variables())

            slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
            self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
            self.assertFalse(slot_update[1] in variables.trainable_variables())

            # Fetch params to validate initial values
            self.assertAllClose(var0_init, self.evaluate(var0))
            self.assertAllClose(var1_init, self.evaluate(var1))

            update = [None] * num_updates
            tot_update = 0
            for step in range(num_updates):
              # Run adadelta update for comparison
              self.evaluate(adadelta_update)

              # Perform initial update without previous accum values
              accum = accum * rho + (grad**2) * (1 - rho)
              update[step] = (
                  np.sqrt(accum_update + epsilon) *
                  (1. / np.sqrt(accum + epsilon)) * grad)
              accum_update = (
                  accum_update * rho + (update[step]**2) * (1.0 - rho))
              tot_update += update[step] * lr

              # Check that the accumulators have been updated
              for slot_idx in range(2):
                self.assertAllCloseAccordingToType(
                    np.array([accum, accum], dtype=dtype),
                    self.evaluate(slot[slot_idx]),
                    rtol=1e-5)

                self.assertAllCloseAccordingToType(
                    np.array([accum_update, accum_update], dtype=dtype),
                    self.evaluate(slot_update[slot_idx]),
                    rtol=1e-5)

              # Check that the parameters have been updated
              self.assertAllCloseAccordingToType(
                  np.array(
                      [var0_init[0] - tot_update, var0_init[1] - tot_update],
                      dtype=dtype),
                  self.evaluate(var0),
                  rtol=1e-5)

              self.assertAllCloseAccordingToType(
                  np.array(
                      [var1_init[0] - tot_update, var1_init[1] - tot_update],
                      dtype=dtype),
                  self.evaluate(var1),
                  rtol=1e-5)


if __name__ == "__main__":
  test.main()
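Aside: the numpy loop in testBasic above is the Adadelta recurrence written out step by step. A minimal standalone sketch of the same recurrence follows; the function name and the scalar values are illustrative only, not part of the diff:

import numpy as np

def adadelta_step(param, grad, accum, accum_update,
                  lr=1.0, rho=0.95, epsilon=1e-8):
  # Running average of squared gradients.
  accum = rho * accum + (1 - rho) * grad**2
  # RMS-scaled update, then running average of squared updates.
  update = np.sqrt(accum_update + epsilon) / np.sqrt(accum + epsilon) * grad
  accum_update = rho * accum_update + (1 - rho) * update**2
  return param - lr * update, accum, accum_update

p, acc, acc_upd = 1.0, 0.0, 0.0
for _ in range(4):  # num_updates in the test
  p, acc, acc_upd = adadelta_step(p, 0.1, acc, acc_upd)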
165
tensorflow/compiler/tests/adagrad_da_test.py
Normal file
@ -0,0 +1,165 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for AdagradDA optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adagrad_da


class AdagradDAOptimizerTest(xla_test.XLATestCase):

  def testAdagradDAWithoutRegularizationBasic1(self):
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
        opt = adagrad_da.AdagradDAOptimizer(
            3.0,
            global_step,
            initial_gradient_squared_accumulator_value=0.1,
            l1_regularization_strength=0.0,
            l2_regularization_strength=0.0)
        update = opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]), global_step=global_step)
        variables.global_variables_initializer().run()

        self.assertAllClose([0.0, 0.0], var0.eval())
        self.assertAllClose([0.0, 0.0], var1.eval())

        # Run a step of AdagradDA
        update.run()

        # Let g be the gradient accumulator, gg the gradient squared
        # accumulator, T the global step, lr the learning rate, and k the
        # initial gradient squared accumulator value.
        # w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg}}
        # For var0[0]: -1.0 * 3.0 * (0.1 - 0) / (0 + sqrt(0.1 + 0.1*0.1))
        # = -0.904534, and similarly for the others.
        self.assertAllCloseAccordingToType(
            np.array([-0.904534, -1.603567]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([-0.094821, -0.189358]), var1.eval())

  def testAdagradDAwithoutRegularizationBasic2(self):
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)

        opt = adagrad_da.AdagradDAOptimizer(
            3.0,
            global_step,
            initial_gradient_squared_accumulator_value=0.1,
            l1_regularization_strength=0.0,
            l2_regularization_strength=0.0)
        update = opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]), global_step=global_step)
        variables.global_variables_initializer().run()

        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
        self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())

        # Run a step of AdagradDA
        update.run()

        self.assertAllCloseAccordingToType(
            np.array([-0.904534, -1.603567]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([-0.094821, -0.189358]), var1.eval())

  def testAdagradDAWithL1(self):
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)

        opt = adagrad_da.AdagradDAOptimizer(
            3.0,
            global_step,
            initial_gradient_squared_accumulator_value=0.1,
            l1_regularization_strength=0.001,
            l2_regularization_strength=0.0)
        update = opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]), global_step=global_step)
        variables.global_variables_initializer().run()

        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
        self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())

        # Run a step of AdagradDA
        update.run()

        self.assertAllCloseAccordingToType(
            np.array([-0.895489, -1.59555]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([-0.085339, -0.17989]), var1.eval())

  def testAdagradDAWithL1_L2(self):
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        global_step = resource_variable_ops.ResourceVariable(
            0, dtype=dtypes.int64)
        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
        var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
        grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)

        opt = adagrad_da.AdagradDAOptimizer(
            3.0,
            global_step,
            initial_gradient_squared_accumulator_value=0.1,
            l1_regularization_strength=0.001,
            l2_regularization_strength=2.0)
        update = opt.apply_gradients(
            zip([grads0, grads1], [var0, var1]), global_step=global_step)
        variables.global_variables_initializer().run()

        self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
        self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())

        # Run a step of AdagradDA
        update.run()

        self.assertAllCloseAccordingToType(
            np.array([-0.046907, -0.093659]), var0.eval())
        self.assertAllCloseAccordingToType(
            np.array([-0.004275, -0.009023]), var1.eval())


if __name__ == "__main__":
  test.main()
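Aside: the expected value in testAdagradDAWithoutRegularizationBasic1 can be reproduced directly from the closed-form comment above. A one-step numpy check (variable names illustrative, not part of the diff):

import numpy as np

# First step for var0[0]: grad accumulator g=0.1, lr=3.0, k=0.1, l1=l2=0, T=1.
g, lr, k, l1, l2, T = 0.1, 3.0, 0.1, 0.0, 0.0, 1
w = np.sign(-g) * lr * max(abs(g) - l1 * T, 0.0) / (l2 * T * lr + np.sqrt(k + g * g))
print(w)  # -0.904534..., the value asserted for var0[0]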
@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@ -28,7 +28,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import adagrad


class AdagradOptimizerTest(XLATestCase):
class AdagradOptimizerTest(xla_test.XLATestCase):

  def testBasic(self):
    for dtype in self.float_types:

@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@ -48,7 +48,7 @@ def adam_update_numpy(param,
  return param_t, m_t, v_t


class AdamOptimizerTest(XLATestCase):
class AdamOptimizerTest(xla_test.XLATestCase):

  def testBasic(self):
    for dtype in self.float_types:
139
tensorflow/compiler/tests/adamax_test.py
Normal file
@ -0,0 +1,139 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for AdaMax optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.contrib.opt.python.training import adamax
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def adamax_update_numpy(param,
                        g_t,
                        t,
                        m,
                        v,
                        alpha=0.001,
                        beta1=0.9,
                        beta2=0.999,
                        epsilon=1e-8):
  m_t = beta1 * m + (1 - beta1) * g_t
  v_t = np.maximum(beta2 * v, np.abs(g_t))
  param_t = param - (alpha / (1 - beta1**t)) * (m_t / (v_t + epsilon))
  return param_t, m_t, v_t


class AdaMaxOptimizerTest(xla_test.XLATestCase):

  def testBasic(self):
    for i, dtype in enumerate(self.float_types):
      with self.test_session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype)

        var0 = resource_variable_ops.ResourceVariable(
            var0_np, name="var0_%d" % i)
        var1 = resource_variable_ops.ResourceVariable(
            var1_np, name="var1_%d" % i)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        opt = adamax.AdaMaxOptimizer()
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        opt_variables = opt.variables()
        beta1_power = opt._get_beta_accumulators()
        self.assertTrue(beta1_power is not None)
        self.assertIn(beta1_power, opt_variables)

        with ops.Graph().as_default():
          # Shouldn't return non-slot variables from other graphs.
          self.assertEqual(0, len(opt.variables()))

        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        beta1_power = opt._get_beta_accumulators()

        # Run 3 steps of AdaMax
        for t in range(1, 4):
          update.run()

          self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())

          var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2)
          self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2)
          self.assertEqual("var0_%d/AdaMax:0" % (i,),
                           opt.get_slot(var=var0, name="m").name)

  def testTensorLearningRate(self):
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        variable_scope.get_variable_scope().set_use_resource(True)
        # Initialize variables for numpy implementation.
        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        var1 = resource_variable_ops.ResourceVariable(var1_np)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)
        opt = adamax.AdaMaxOptimizer(constant_op.constant(0.001))
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
        variables.global_variables_initializer().run()

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        beta1_power = opt._get_beta_accumulators()

        # Run 3 steps of AdaMax
        for t in range(1, 4):
          self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
          update.run()

          var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
          var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, var0.eval())
          self.assertAllCloseAccordingToType(var1_np, var1.eval())


if __name__ == "__main__":
  test.main()
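Aside: what distinguishes AdaMax from Adam in adamax_update_numpy above is the infinity-norm accumulator v_t = max(beta2*v, |g|) in place of Adam's second-moment average. One hand-computed step with the test's values (illustrative, not part of the diff):

import numpy as np

param, g, m, v, t = 1.0, 0.1, 0.0, 0.0, 1
alpha, beta1, beta2, epsilon = 0.001, 0.9, 0.999, 1e-8
m = beta1 * m + (1 - beta1) * g        # 0.01
v = np.maximum(beta2 * v, np.abs(g))   # 0.1: max of decayed v and |g|
param -= (alpha / (1 - beta1**t)) * (m / (v + epsilon))
print(param)  # ~0.999, i.e. a step of roughly alpha on the first update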
142
tensorflow/compiler/tests/addsign_test.py
Normal file
@ -0,0 +1,142 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for AddSign."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.contrib.opt.python.training import addsign
from tensorflow.contrib.opt.python.training import sign_decay
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def py_linear_decay_fn(decay_steps):

  def linear_decay(step):
    step = min(step, decay_steps)
    return float(decay_steps - step) / decay_steps

  return linear_decay


def addsign_update_numpy(params,
                         g_t,
                         m,
                         lr,
                         alpha=1.0,
                         beta=0.9,
                         py_sign_decay_fn=None,
                         t=None):
  m_t = beta * m + (1 - beta) * g_t
  if py_sign_decay_fn is None:
    sign_decayed = 1.0
  else:
    sign_decayed = py_sign_decay_fn(t - 1)
  multiplier = alpha + sign_decayed * np.sign(g_t) * np.sign(m_t)
  params_t = params - lr * multiplier * g_t
  return params_t, m_t


class AddSignTest(xla_test.XLATestCase):

  def _testDense(self,
                 learning_rate=0.1,
                 sign_decay_fn=None,
                 py_sign_decay_fn=None,
                 alpha=1.0,
                 beta=0.9):
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        # Initialize variables for numpy implementation.
        m0, m1 = 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        var1 = resource_variable_ops.ResourceVariable(var1_np)
        global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        opt = addsign.AddSignOptimizer(
            learning_rate=learning_rate,
            alpha=alpha,
            beta=beta,
            sign_decay_fn=sign_decay_fn,
        )
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
                                     global_step=global_step)
        neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
                                         global_step=global_step)
        variables.global_variables_initializer().run()

        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        # Run 7 steps of AddSign:
        # first 4 steps with positive gradient,
        # last 3 steps with negative gradient (sign(gm) should be -1).
        for t in range(1, 8):
          if t < 5:
            update.run()
          else:
            neg_update.run()

          var0_np, m0 = addsign_update_numpy(
              var0_np,
              grads0_np if t < 5 else -grads0_np,
              m0,
              learning_rate,
              alpha=alpha,
              beta=beta,
              py_sign_decay_fn=py_sign_decay_fn,
              t=t,
          )
          var1_np, m1 = addsign_update_numpy(
              var1_np,
              grads1_np if t < 5 else -grads1_np,
              m1,
              learning_rate,
              alpha=alpha,
              beta=beta,
              py_sign_decay_fn=py_sign_decay_fn,
              t=t,
          )

          # Validate updated params
          self.assertAllCloseAccordingToType(
              var0_np, var0.eval(), half_rtol=1e-2)
          self.assertAllCloseAccordingToType(var1_np, var1.eval())

  def testDense(self):
    decay_steps = 10
    sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
    py_sign_decay_fn = py_linear_decay_fn(decay_steps)
    self._testDense()
    self._testDense(learning_rate=0.01, alpha=0.1, beta=0.8)
    self._testDense(
        sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)


if __name__ == '__main__':
  test.main()
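Aside: the AddSign rule in addsign_update_numpy scales each step by alpha plus the (possibly decayed) agreement between the gradient sign and the momentum sign. A one-step numeric trace with the defaults used by _testDense (values illustrative, not part of the diff):

import numpy as np

param, g, m = 1.0, 0.1, 0.0
lr, alpha, beta = 0.1, 1.0, 0.9
m = beta * m + (1 - beta) * g                  # 0.01, same sign as g
multiplier = alpha + np.sign(g) * np.sign(m)   # 1 + 1 = 2: signs agree
param -= lr * multiplier * g                   # 1.0 - 0.1*2*0.1 = 0.98
print(param)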
@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest


class BinaryOpsTest(XLATestCase):
class BinaryOpsTest(xla_test.XLATestCase):
  """Test cases for binary operators."""

  def _testBinary(self, op, a, b, expected, equality_test=None):
@ -691,11 +691,13 @@ class BinaryOpsTest(XLATestCase):
          np.array([[10], [7], [2]], dtype=np.float32),
          np.float32(7),
          expected=np.array([[False], [False], [True]], dtype=np.bool))
      self._testBinary(
          less_op,
          np.array([[10], [7], [2], [-1]], dtype=np.int64),
          np.int64(7),
          expected=np.array([[False], [False], [True], [True]], dtype=np.bool))
      if np.int64 in self.numeric_types:
        self._testBinary(
            less_op,
            np.array([[10], [7], [2], [-1]], dtype=np.int64),
            np.int64(7),
            expected=np.array(
                [[False], [False], [True], [True]], dtype=np.bool))

    for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]:
      self._testBinary(

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class BucketizationOpTest(XLATestCase):
class BucketizationOpTest(xla_test.XLATestCase):

  def testInt(self):
    with self.test_session() as sess:

@ -22,7 +22,7 @@ import collections

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest

# TODO(srvasude): Merge this with
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
class CategoricalTest(XLATestCase):
class CategoricalTest(xla_test.XLATestCase):
  """Test cases for random-number generating operators."""

  def output_dtypes(self):

@ -18,12 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@ -32,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class CholeskyOpTest(XLATestCase):
class CholeskyOpTest(xla_test.XLATestCase):

  # Cholesky defined for float64, float32, complex64, complex128
  # (https://www.tensorflow.org/api_docs/python/tf/cholesky)
@ -103,9 +101,8 @@ class CholeskyOpTest(XLATestCase):
      with self.assertRaises(ValueError):
        linalg_ops.cholesky(tensor3)

  @unittest.skip("Test is slow")
  def testLarge(self):
    n = 200
  def testLarge2000x2000(self):
    n = 2000
    shape = (n, n)
    data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag(
        np.ones(n).astype(np.float32))
@ -128,6 +125,5 @@ class CholeskyOpTest(XLATestCase):
    matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype)
    self._verifyCholesky(matrix, atol=1e-4)


if __name__ == "__main__":
  test.main()

@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"


class ClusteringTest(XLATestCase):
class ClusteringTest(xla_test.XLATestCase):

  def testAdd(self):
    val1 = np.array([4, 3, 2, 1], dtype=np.float32)

@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest


class ConcatTest(XLATestCase):
class ConcatTest(xla_test.XLATestCase):

  def testHStack(self):
    with self.test_session():
@ -292,7 +292,7 @@ class ConcatTest(XLATestCase):
        array_ops.concat([scalar, scalar, scalar], dim)


class ConcatOffsetTest(XLATestCase):
class ConcatOffsetTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_session() as sess:
@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase):
      self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])


class PackTest(XLATestCase):
class PackTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_session() as sess:
@ -26,23 +26,20 @@ from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest


DATA_FORMATS = (
    ("_data_format_NHWC", "NHWC"),
    ("_data_format_NCHW", "NCHW"),
    ("_data_format_HWNC", "HWNC"),
    ("_data_format_HWCN", "HWCN"),
)


class Conv2DTest(XLATestCase, parameterized.TestCase):
class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):

  def _VerifyValues(self,
                    input_sizes=None,
@ -236,7 +233,7 @@ class Conv2DTest(XLATestCase, parameterized.TestCase):
          expected=np.reshape([108, 128], [1, 1, 1, 2]))


class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):

  def _VerifyValues(self,
                    input_sizes=None,
@ -534,7 +531,7 @@ class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
          expected=[5, 0, 11, 0, 0, 0, 17, 0, 23])


class Conv2DBackpropFilterTest(XLATestCase, parameterized.TestCase):
class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):

  def _VerifyValues(self,
                    input_sizes=None,

@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest

# Test cloned from
# tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
class Conv3DBackpropFilterV2GradTest(XLATestCase):
class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):

  def testGradient(self):
    with self.test_session(), self.test_scope():
@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase):


# Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py
class Conv3DTransposeTest(XLATestCase):
class Conv3DTransposeTest(xla_test.XLATestCase):

  def testConv3DTransposeSingleStride(self):
    with self.test_session(), self.test_scope():

@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -114,7 +114,7 @@ def CheckGradConfigsToTest():
    yield i, f, o, s, p


class DepthwiseConv2DTest(XLATestCase):
class DepthwiseConv2DTest(xla_test.XLATestCase):

  # This is testing that depthwise_conv2d and depthwise_conv2d_native
  # produce the same results. It also tests that NCHW and NWHC

@ -20,14 +20,14 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class DynamicUpdateSliceOpsTest(XLATestCase):
class DynamicUpdateSliceOpsTest(xla_test.XLATestCase):

  def _assertOpOutputMatchesExpected(self, op, args, expected):
    with self.test_session() as session:

@ -20,14 +20,14 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import googletest


class DynamicStitchTest(XLATestCase):
class DynamicStitchTest(xla_test.XLATestCase):

  def _AssertDynamicStitchResultIs(self, indices, data, expected):
    with self.test_session() as session:

@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@ -40,7 +40,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import adam


class EagerTest(XLATestCase):
class EagerTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
@ -286,7 +286,7 @@ class EagerTest(XLATestCase):
                        [2.0, 2.0]], embedding_matrix.numpy())


class EagerFunctionTest(XLATestCase):
class EagerFunctionTest(xla_test.XLATestCase):

  def testBasic(self):
    with self.test_scope():
@ -403,7 +403,7 @@ class EagerFunctionTest(XLATestCase):
  def testSliceInDefun(self):
    with self.test_scope():

      @function.defun(compiled=True)
      @function.defun
      def f(x, y):
        return x[0::2, y:, ...]

@ -418,8 +418,24 @@ class EagerFunctionTest(XLATestCase):
    self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
    self.assertAllEqual((2, 3, 4), dz.shape.as_list())

  def testNestedDefun(self):
    self.skipTest('Nested defuns do not work on TPU at the moment')
    with self.test_scope():

class ExcessivePaddingTest(XLATestCase):
      @function.defun
      def times_two(x):
        return 2 * x

      @function.defun
      def two_x_plus_1(x):
        return times_two(x) + 1

      x = constant_op.constant([2, 3, 4])
      y = two_x_plus_1(x)
      self.assertAllEqual([5, 7, 9], y.numpy())


class ExcessivePaddingTest(xla_test.XLATestCase):
  """Test that eager execution works with TPU flattened tensors.

  Tensors that would normally be excessively padded when written
@ -470,6 +486,36 @@ class ExcessivePaddingTest(XLATestCase):
    self.assertAllEqual(100 * [[36.0]], reduced)


def multiple_tpus():
  devices = context.context().devices()
  return len([d for d in devices if 'device:TPU:' in d]) > 1


class MultiDeviceTest(xla_test.XLATestCase):
  """Test running TPU computation on more than one core."""

  def testBasic(self):
    if not multiple_tpus():
      self.skipTest('MultiDeviceTest requires multiple TPU devices.')

    # Compute 10 on TPU core 0
    with ops.device('device:TPU:0'):
      two = constant_op.constant(2)
      five = constant_op.constant(5)
      ten = two * five
      self.assertAllEqual(10, ten)

    # Compute 6 on TPU core 1
    with ops.device('device:TPU:1'):
      two = constant_op.constant(2)
      three = constant_op.constant(3)
      six = two * three
      self.assertAllEqual(6, six)

    # Copy 10 and 6 to CPU and sum them
    self.assertAllEqual(16, ten + six)


if __name__ == '__main__':
  ops.enable_eager_execution(
      config=config_pb2.ConfigProto(log_device_placement=True))
@ -20,13 +20,13 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ExtractImagePatches(XLATestCase):
class ExtractImagePatches(xla_test.XLATestCase):
  """Functional tests for ExtractImagePatches op."""

  def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):

@ -17,14 +17,14 @@ from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import googletest


class FakeQuantWithMinMaxArgsTest(XLATestCase):
class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase):
  """Test cases for FakeQuantWithMinMaxArgs operation."""

  # 8 bits, wide range.
@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase):
        result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)


class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase):
  """Test cases for FakeQuantWithMinMaxArgsGradient operation."""

  # 8 bits, wide range.
@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
        bfloat16_rtol=0.03)


class FakeQuantWithMinMaxVarsTest(XLATestCase):
class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase):
  """Test cases for FakeQuantWithMinMaxVars operation."""

  # 8 bits, wide range.
@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase):
        result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)


class FakeQuantWithMinMaxVarsGradientTest(XLATestCase):
class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase):
  """Test cases for FakeQuantWithMinMaxVarsGradient operation."""

  # 8 bits, wide range.

@ -23,7 +23,7 @@ import itertools
import numpy as np
import scipy.signal as sps

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@ -58,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2))
INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2))


class FFTTest(XLATestCase):
class FFTTest(xla_test.XLATestCase):

  def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
                       tf_method):
201
tensorflow/compiler/tests/fifo_queue_test.py
Normal file
@ -0,0 +1,201 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test


class FIFOQueueTest(xla_test.XLATestCase):

  def testEnqueue(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      enqueue_op = q.enqueue((10.0,))
      enqueue_op.run()

  def testEnqueueWithShape(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
      enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
      enqueue_correct_op.run()
      with self.assertRaises(ValueError):
        q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
      self.assertEqual(1, q.size().eval())

  def testMultipleDequeues(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
      self.evaluate(q.enqueue([1]))
      self.evaluate(q.enqueue([2]))
      self.evaluate(q.enqueue([3]))
      a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
      self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))

  def testQueuesDontShare(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
      self.evaluate(q.enqueue(1))
      q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
      self.evaluate(q2.enqueue(2))
      self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
      self.assertAllEqual(self.evaluate(q.dequeue()), 1)

  def testEnqueueDictWithoutNames(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      with self.assertRaisesRegexp(ValueError, "must have names"):
        q.enqueue({"a": 12.0})

  def testParallelEnqueue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      # Run one producer thread for each element in elems.
      def enqueue(enqueue_op):
        sess.run(enqueue_op)

      threads = [
          self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
      ]
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()

      # Dequeue every element using a single thread.
      results = []
      for _ in xrange(len(elems)):
        results.append(dequeued_t.eval())
      self.assertItemsEqual(elems, results)

  def testParallelDequeue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      # Enqueue every element using a single thread.
      for enqueue_op in enqueue_ops:
        enqueue_op.run()

      # Run one consumer thread for each element in elems.
      results = []

      def dequeue():
        results.append(sess.run(dequeued_t))

      threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
      for thread in threads:
        thread.start()
      for thread in threads:
        thread.join()
      self.assertItemsEqual(elems, results)

  def testDequeue(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      for enqueue_op in enqueue_ops:
        enqueue_op.run()

      for i in xrange(len(elems)):
        vals = dequeued_t.eval()
        self.assertEqual([elems[i]], vals)

  def testEnqueueAndBlockingDequeue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
      elems = [10.0, 20.0, 30.0]
      enqueue_ops = [q.enqueue((x,)) for x in elems]
      dequeued_t = q.dequeue()

      def enqueue():
        # The enqueue_ops should run after the dequeue op has blocked.
        # TODO(mrry): Figure out how to do this without sleeping.
        time.sleep(0.1)
        for enqueue_op in enqueue_ops:
          sess.run(enqueue_op)

      results = []

      def dequeue():
        for _ in xrange(len(elems)):
          results.append(sess.run(dequeued_t))

      enqueue_thread = self.checkedThread(target=enqueue)
      dequeue_thread = self.checkedThread(target=dequeue)
      enqueue_thread.start()
      dequeue_thread.start()
      enqueue_thread.join()
      dequeue_thread.join()

      for elem, result in zip(elems, results):
        self.assertEqual([elem], result)

  def testMultiEnqueueAndDequeue(self):
    with self.test_session() as sess, self.test_scope():
      q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
      elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
      enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
      dequeued_t = q.dequeue()

      for enqueue_op in enqueue_ops:
        enqueue_op.run()

      for i in xrange(len(elems)):
        x_val, y_val = sess.run(dequeued_t)
        x, y = elems[i]
        self.assertEqual([x], x_val)
        self.assertEqual([y], y_val)

  def testQueueSizeEmpty(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      self.assertEqual([0], q.size().eval())

  def testQueueSizeAfterEnqueueAndDequeue(self):
    with self.test_session(), self.test_scope():
      q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
      enqueue_op = q.enqueue((10.0,))
      dequeued_t = q.dequeue()
      size = q.size()
      self.assertEqual([], size.get_shape())

      enqueue_op.run()
      self.assertEqual(1, size.eval())
      dequeued_t.op.run()
      self.assertEqual(0, size.eval())


if __name__ == "__main__":
  test.main()
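Aside: outside the XLA test harness, the same queue API is driven through an ordinary session. A minimal TF1-style round trip (standalone script for illustration, not part of the diff):

import tensorflow as tf

q = tf.FIFOQueue(10, tf.float32)
enqueue = q.enqueue_many(([10.0, 20.0],))
dequeue = q.dequeue()
with tf.Session() as sess:
  sess.run(enqueue)
  print(sess.run(dequeue), sess.run(dequeue))  # 10.0 20.0, FIFO order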
@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent


class FtrlOptimizerTest(XLATestCase):
class FtrlOptimizerTest(xla_test.XLATestCase):

  def initVariableAndGradient(self, dtype):
    var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)

@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest


class FunctionTest(XLATestCase):
class FunctionTest(xla_test.XLATestCase):

  def testFunction(self):
    """Executes a simple TensorFlow function."""

@ -22,7 +22,7 @@ from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
@ -30,7 +30,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.platform import test


class FusedBatchNormTest(XLATestCase, parameterized.TestCase):
class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):

  def _reference_training(self, x, scale, offset, epsilon, data_format):
    if data_format != "NHWC":

@ -20,13 +20,13 @@ from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class GatherNdTest(XLATestCase):
class GatherNdTest(xla_test.XLATestCase):

  def _runGather(self, params, indices):
    with self.test_session():

@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase):
      self.assertAllEqual(
          [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]}))

  def testGatherPrecision(self):
    with self.test_session() as session, self.test_scope():
      data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0],
                       [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]])
      indices = np.array([1, 2, 3, 1])
      dtype = dtypes.float32
      params_np = self._buildParams(data, dtype)
      params = array_ops.placeholder(dtype=dtype)
      indices_tf = constant_op.constant(indices)
      gather_t = array_ops.gather(params, indices_tf)
      gather_val = session.run(gather_t, feed_dict={params: params_np})
      np_val = params_np[indices]
      self.assertAllEqual(np_val, gather_val)


class GatherBenchmark(test.Benchmark):
  """Microbenchmarks for the gather op."""
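Aside on testGatherPrecision above: a plausible reading of the constant 2 * (1 + np.exp2(-8)) is that it differs from 2.0 only in low-order mantissa bits that a bfloat16 round trip would discard, so the exact-equality assert fails if the gather path internally drops to reduced precision. In numpy terms (illustrative):

import numpy as np

x = np.float32(2 * (1 + np.exp2(-8)))
print(x == np.float32(2.0))  # False: float32 keeps the low-order bits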
@@ -25,7 +25,7 @@ import numpy as np

from six.moves import xrange  # pylint: disable=redefined-builtin

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,7 +41,7 @@ def GenerateNumpyRandomRGB(shape):
  return np.random.randint(0, 256, shape) / 256.


-class RGBToHSVTest(XLATestCase):
+class RGBToHSVTest(xla_test.XLATestCase):

  def testBatch(self):
    # Build an arbitrary RGB image
@@ -104,7 +104,7 @@ class RGBToHSVTest(XLATestCase):
    self.assertAllCloseAccordingToType(hsv_tf, hsv_np)


-class AdjustContrastTest(XLATestCase):
+class AdjustContrastTest(xla_test.XLATestCase):

  def _testContrast(self, x_np, y_np, contrast_factor):
    with self.test_session():
@@ -168,7 +168,7 @@ class AdjustContrastTest(XLATestCase):
    self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5)


-class AdjustHueTest(XLATestCase):
+class AdjustHueTest(xla_test.XLATestCase):

  def testAdjustNegativeHue(self):
    x_shape = [2, 2, 3]
@@ -303,7 +303,7 @@ class AdjustHueTest(XLATestCase):
    self._adjustHueTf(x_np, delta_h)


-class AdjustSaturationTest(XLATestCase):
+class AdjustSaturationTest(xla_test.XLATestCase):

  def _adjust_saturation(self, image, saturation_factor):
    image = ops.convert_to_tensor(image, name="image")
@@ -403,7 +403,7 @@ class AdjustSaturationTest(XLATestCase):
    self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5)


-class ResizeBilinearTest(XLATestCase):
+class ResizeBilinearTest(xla_test.XLATestCase):

  def _assertForwardOpMatchesExpected(self,
                                      image_np,
@@ -22,7 +22,7 @@ import copy

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"

# Local response normalization tests. The forward tests are copied from
# tensorflow/python/kernel_tests/lrn_op_test.py
-class LRNTest(XLATestCase):
+class LRNTest(xla_test.XLATestCase):

  def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0,
           beta=0.5):
@@ -19,14 +19,14 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


-class MatrixBandPartTest(XLATestCase):
+class MatrixBandPartTest(xla_test.XLATestCase):

  def _testMatrixBandPart(self, dtype, shape):
    with self.test_session():
@@ -22,7 +22,7 @@ import itertools

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -35,7 +35,7 @@ def MakePlaceholder(x):
  return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)


-class MatrixTriangularSolveOpTest(XLATestCase):
+class MatrixTriangularSolveOpTest(xla_test.XLATestCase):

  # MatrixTriangularSolve defined for float64, float32, complex64, complex128
  # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve)
@@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib


-class MomentumOptimizerTest(XLATestCase):
+class MomentumOptimizerTest(xla_test.XLATestCase):

  def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
    var += accum * lr * momentum
@@ -22,14 +22,14 @@ import unittest

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest


-class NAryOpsTest(XLATestCase):
+class NAryOpsTest(xla_test.XLATestCase):

  def _testNAry(self, op, args, expected, equality_fn=None):
    with self.test_session() as session:
@@ -20,13 +20,13 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import googletest


-class NullaryOpsTest(XLATestCase):
+class NullaryOpsTest(xla_test.XLATestCase):

  def _testNullary(self, op, expected):
    with self.test_session() as session:
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest


-class PlaceholderTest(XLATestCase):
+class PlaceholderTest(xla_test.XLATestCase):

  def test_placeholder_with_default_default(self):
    with self.test_session() as sess, self.test_scope():
@@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding):
      padding=padding)


-class Pooling3DTest(XLATestCase):
+class Pooling3DTest(xla_test.XLATestCase):

  def _VerifyValues(self, pool_func, input_sizes, window, strides, padding,
                    expected):
@@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -69,7 +69,7 @@ def GetTestConfigs():
  return test_configs


-class PoolingTest(XLATestCase):
+class PoolingTest(xla_test.XLATestCase):

  def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
                     data_format, expected):
@@ -288,7 +288,7 @@ class PoolingTest(XLATestCase):
        expected=expected_output)


-class PoolGradTest(XLATestCase):
+class PoolGradTest(xla_test.XLATestCase):

  CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
tensorflow/compiler/tests/powersign_test.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for PowerSign."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.contrib.opt.python.training import powersign
from tensorflow.contrib.opt.python.training import sign_decay
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def py_linear_decay_fn(decay_steps):
  def linear_decay(step):
    step = min(step, decay_steps)
    return float(decay_steps - step) / decay_steps
  return linear_decay


def powersign_update_numpy(params,
                           g_t,
                           m,
                           lr,
                           base=math.e,
                           beta=0.9,
                           py_sign_decay_fn=None,
                           t=None):
  m_t = beta * m + (1 - beta) * g_t
  if py_sign_decay_fn is None:
    sign_decayed = 1.0
  else:
    sign_decayed = py_sign_decay_fn(t-1)
  multiplier = base ** (sign_decayed * np.sign(g_t) * np.sign(m_t))
  params_t = params - lr * multiplier * g_t
  return params_t, m_t


class PowerSignTest(xla_test.XLATestCase):

  def _testDense(self,
                 learning_rate=0.1,
                 sign_decay_fn=None,
                 py_sign_decay_fn=None,
                 base=math.e,
                 beta=0.9):
    for dtype in self.float_types:
      with self.test_session(), self.test_scope():
        # Initialize variables for numpy implementation.
        m0, m1 = 0.0, 0.0
        var0_np = np.array([1.0, 2.0], dtype=dtype)
        grads0_np = np.array([0.1, 0.1], dtype=dtype)
        var1_np = np.array([3.0, 4.0], dtype=dtype)
        grads1_np = np.array([0.01, 0.01], dtype=dtype)

        var0 = resource_variable_ops.ResourceVariable(var0_np)
        var1 = resource_variable_ops.ResourceVariable(var1_np)
        global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
        grads0 = constant_op.constant(grads0_np)
        grads1 = constant_op.constant(grads1_np)

        opt = powersign.PowerSignOptimizer(
            learning_rate=learning_rate,
            base=base,
            beta=beta,
            sign_decay_fn=sign_decay_fn,
        )
        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
                                     global_step=global_step)
        neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
                                         global_step=global_step)

        variables.global_variables_initializer().run()
        # Fetch params to validate initial values
        self.assertAllClose([1.0, 2.0], var0.eval())
        self.assertAllClose([3.0, 4.0], var1.eval())

        # Run 7 steps of powersign
        # first 4 steps with positive gradient
        # last 3 steps with negative gradient (sign(gm) should be -1)
        for t in range(1, 8):
          if t < 5:
            update.run()
          else:
            neg_update.run()

          var0_np, m0 = powersign_update_numpy(
              var0_np,
              grads0_np if t < 5 else -grads0_np,
              m0,
              learning_rate,
              base=base,
              beta=beta,
              py_sign_decay_fn=py_sign_decay_fn,
              t=t,
          )
          var1_np, m1 = powersign_update_numpy(
              var1_np,
              grads1_np if t < 5 else -grads1_np,
              m1,
              learning_rate,
              base=base,
              beta=beta,
              py_sign_decay_fn=py_sign_decay_fn,
              t=t,
          )

          # Validate updated params
          self.assertAllCloseAccordingToType(var0_np, var0.eval())
          self.assertAllCloseAccordingToType(var1_np, var1.eval())

  def testDense(self):
    decay_steps = 10
    sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
    py_sign_decay_fn = py_linear_decay_fn(decay_steps)
    self._testDense()
    self._testDense(learning_rate=0.1, base=10.0, beta=0.8)
    self._testDense(
        sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)


if __name__ == '__main__':
  test.main()
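Note: powersign_update_numpy above is the reference for the update being tested. With sign-decay schedule d(t) (1.0 when no schedule is set), base b, decay beta, and learning rate eta, the step it implements is:

\begin{aligned}
m_t &= \beta\, m_{t-1} + (1 - \beta)\, g_t \\
\theta_t &= \theta_{t-1} - \eta\; b^{\,d(t)\,\operatorname{sign}(g_t)\,\operatorname{sign}(m_t)}\; g_t
\end{aligned}

So the step is scaled up when gradient and momentum agree in sign and scaled down when they disagree, which is why the test flips the gradient sign after step 4.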
tensorflow/compiler/tests/proximal_adagrad_test.py (new file, 172 lines)
@@ -0,0 +1,172 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Proximal Adagrad optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adagrad
from tensorflow.python.training import proximal_adagrad


class ProximalAdagradOptimizerTest(xla_test.XLATestCase):

  def testResourceProximalAdagradwithoutRegularization(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
      var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])
      opt = proximal_adagrad.ProximalAdagradOptimizer(
          3.0,
          initial_accumulator_value=0.1,
          l1_regularization_strength=0.0,
          l2_regularization_strength=0.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([0.0, 0.0], var0.eval())
      self.assertAllClose([0.0, 0.0], var1.eval())

      # Run 3 steps Proximal Adagrad.
      for _ in range(3):
        update.run()

      self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval())
      self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval())
      opt_vars = opt.variables()
      self.assertStartsWith(opt_vars[0].name, var0._shared_name)
      self.assertStartsWith(opt_vars[1].name, var1._shared_name)
      self.assertEqual(2, len(opt_vars))

  def testProximalAdagradwithoutRegularization2(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
      var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])

      opt = proximal_adagrad.ProximalAdagradOptimizer(
          3.0,
          initial_accumulator_value=0.1,
          l1_regularization_strength=0.0,
          l2_regularization_strength=0.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([4.0, 3.0], var1.eval())

      # Run 3 steps Proximal Adagrad.
      for _ in range(3):
        update.run()
      self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval())
      self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval())

  def testProximalAdagradWithL1(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
      var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])

      opt = proximal_adagrad.ProximalAdagradOptimizer(
          3.0,
          initial_accumulator_value=0.1,
          l1_regularization_strength=0.001,
          l2_regularization_strength=0.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([4.0, 3.0], var1.eval())

      # Run 10 steps Proximal Adagrad
      for _ in range(10):
        update.run()
      self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval())
      self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval())

  def testProximalAdagradWithL1_L2(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
      var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])

      opt = proximal_adagrad.ProximalAdagradOptimizer(
          3.0,
          initial_accumulator_value=0.1,
          l1_regularization_strength=0.001,
          l2_regularization_strength=2.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([4.0, 3.0], var1.eval())

      # Run 10 steps Proximal Adagrad.
      for _ in range(10):
        update.run()

      self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval())
      self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval())

  def applyOptimizer(self, opt, steps=5):
    var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
    var1 = resource_variable_ops.ResourceVariable([3.0, 4.0])
    grads0 = constant_op.constant([0.1, 0.2])
    grads1 = constant_op.constant([0.01, 0.02])

    update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
    variables.global_variables_initializer().run()

    self.assertAllClose([1.0, 2.0], var0.eval())
    self.assertAllClose([3.0, 4.0], var1.eval())

    # Run ProximalAdagrad for a few steps
    for _ in range(steps):
      update.run()

    return var0.eval(), var1.eval()

  def testEquivAdagradwithoutRegularization(self):
    with self.test_session(), self.test_scope():
      val0, val1 = self.applyOptimizer(
          proximal_adagrad.ProximalAdagradOptimizer(
              3.0,
              initial_accumulator_value=0.1,
              l1_regularization_strength=0.0,
              l2_regularization_strength=0.0))

    with self.test_session(), self.test_scope():
      val2, val3 = self.applyOptimizer(
          adagrad.AdagradOptimizer(
              3.0, initial_accumulator_value=0.1))

    self.assertAllClose(val0, val2)
    self.assertAllClose(val1, val3)


if __name__ == "__main__":
  test.main()
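Note: for orientation, ProximalAdagrad pairs the Adagrad step size with a proximal (soft-thresholding) step for the regularizers. A sketch of the per-element update, up to kernel-level details that the tests above pin down numerically (this is not code or text from the commit):

\begin{aligned}
a_t &= a_{t-1} + g_t^2, \qquad \eta_t = \eta / \sqrt{a_t}, \\
v &= \theta_{t-1} - \eta_t\, g_t, \\
\theta_t &= \operatorname{sign}(v)\, \frac{\max(|v| - \eta_t \lambda_1,\, 0)}{1 + \eta_t \lambda_2}.
\end{aligned}

With \lambda_1 = \lambda_2 = 0 the proximal step is the identity and the update reduces to plain Adagrad, which is exactly what testEquivAdagradwithoutRegularization asserts.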
tensorflow/compiler/tests/proximal_gradient_descent_test.py (new file, 156 lines)
@@ -0,0 +1,156 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Proximal Gradient Descent optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import proximal_gradient_descent


class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):

  def testResourceProximalGradientDescentwithoutRegularization(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
      var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])
      opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
          3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([0.0, 0.0], var0.eval())
      self.assertAllClose([0.0, 0.0], var1.eval())

      # Run 3 steps Proximal Gradient Descent.
      for _ in range(3):
        update.run()

      self.assertAllClose(np.array([-0.9, -1.8]), var0.eval())
      self.assertAllClose(np.array([-0.09, -0.18]), var1.eval())

  def testProximalGradientDescentwithoutRegularization2(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
      var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])

      opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
          3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([4.0, 3.0], var1.eval())

      # Run 3 steps Proximal Gradient Descent
      for _ in range(3):
        update.run()

      self.assertAllClose(np.array([0.1, 0.2]), var0.eval())
      self.assertAllClose(np.array([3.91, 2.82]), var1.eval())

  def testProximalGradientDescentWithL1(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
      var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])

      opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
          3.0, l1_regularization_strength=0.001, l2_regularization_strength=0.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([4.0, 3.0], var1.eval())

      # Run 10 steps proximal gradient descent.
      for _ in range(10):
        update.run()

      self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval())
      self.assertAllClose(np.array([3.67, 2.37]), var1.eval())

  def testProximalGradientDescentWithL1_L2(self):
    with self.test_session(), self.test_scope():
      var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
      var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
      grads0 = constant_op.constant([0.1, 0.2])
      grads1 = constant_op.constant([0.01, 0.02])

      opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
          3.0, l1_regularization_strength=0.001, l2_regularization_strength=2.0)
      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
      variables.global_variables_initializer().run()

      self.assertAllClose([1.0, 2.0], var0.eval())
      self.assertAllClose([4.0, 3.0], var1.eval())

      # Run 10 steps Proximal Gradient Descent
      for _ in range(10):
        update.run()

      self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval())
      self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval())

  def applyOptimizer(self, opt, steps=5):
    var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
    var1 = resource_variable_ops.ResourceVariable([3.0, 4.0])
    grads0 = constant_op.constant([0.1, 0.2])
    grads1 = constant_op.constant([0.01, 0.02])

    update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
    variables.global_variables_initializer().run()

    self.assertAllClose([1.0, 2.0], var0.eval())
    self.assertAllClose([3.0, 4.0], var1.eval())

    # Run ProximalAdagrad for a few steps
    for _ in range(steps):
      update.run()

    return var0.eval(), var1.eval()

  def testEquivGradientDescentwithoutRegularization(self):
    with self.test_session(), self.test_scope():
      val0, val1 = self.applyOptimizer(
          proximal_gradient_descent.ProximalGradientDescentOptimizer(
              3.0,
              l1_regularization_strength=0.0,
              l2_regularization_strength=0.0))

    with self.test_session(), self.test_scope():
      val2, val3 = self.applyOptimizer(
          gradient_descent.GradientDescentOptimizer(3.0))

    self.assertAllClose(val0, val2)
    self.assertAllClose(val1, val3)


if __name__ == "__main__":
  test.main()
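Note: the no-regularization expectations above can be checked by hand, since proximal gradient descent with zero regularization is plain gradient descent. A quick numpy check (illustrative only, not part of the commit):

import numpy as np

# var <- var - lr * grad, applied three times, matches the test's expectation.
lr = 3.0
var0 = np.array([0.0, 0.0])
grads0 = np.array([0.1, 0.2])
for _ in range(3):
  var0 -= lr * grads0
print(var0)  # -> [-0.9 -1.8], the expected var0 after 3 steps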
tensorflow/compiler/tests/qr_op_test.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):

  def AdjustedNorm(self, x):
    """Computes the norm of matrices in 'x', adjusted for dimension and type."""
    norm = np.linalg.norm(x, axis=(-2, -1))
    return norm / (max(x.shape[-2:]) * np.finfo(x.dtype).eps)

  def CompareOrthogonal(self, x, y, rank):
    # We only compare the first 'rank' orthogonal vectors since the
    # remainder form an arbitrary orthonormal basis for the
    # (row- or column-) null space, whose exact value depends on
    # implementation details. Notice that since we check that the
    # matrices of singular vectors are unitary elsewhere, we do
    # implicitly test that the trailing vectors of x and y span the
    # same space.
    x = x[..., 0:rank]
    y = y[..., 0:rank]
    # Q is only unique up to sign (complex phase factor for complex matrices),
    # so we normalize the sign first.
    sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
    phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
    x *= phases
    self.assertTrue(np.all(self.AdjustedNorm(x - y) < 30.0))

  def CheckApproximation(self, a, q, r):
    # Tests that a ~= q*r.
    precision = self.AdjustedNorm(a - np.matmul(q, r))
    self.assertTrue(np.all(precision < 10.0))

  def CheckUnitary(self, x):
    # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
    xx = math_ops.matmul(x, x, adjoint_a=True)
    identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
    precision = self.AdjustedNorm(xx.eval() - identity.eval())
    self.assertTrue(np.all(precision < 5.0))

  def _test(self, dtype, shape, full_matrices):
    np.random.seed(1)
    x_np = np.random.uniform(
        low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)

    with self.test_session() as sess:
      x_tf = array_ops.placeholder(dtype)
      with self.test_scope():
        q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
      q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})

      q_dims = q_tf_val.shape
      np_q = np.ndarray(q_dims, dtype)
      np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
      new_first_dim = np_q_reshape.shape[0]

      x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
      for i in range(new_first_dim):
        if full_matrices:
          np_q_reshape[i, :, :], _ = np.linalg.qr(
              x_reshape[i, :, :], mode="complete")
        else:
          np_q_reshape[i, :, :], _ = np.linalg.qr(
              x_reshape[i, :, :], mode="reduced")
      np_q = np.reshape(np_q_reshape, q_dims)
      self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:]))
      self.CheckApproximation(x_np, q_tf_val, r_tf_val)
      self.CheckUnitary(q_tf_val)

  SIZES = [1, 2, 5, 10, 32, 100, 300]
  DTYPES = [np.float32]
  PARAMS = itertools.product(SIZES, SIZES, DTYPES)

  @parameterized.parameters(*PARAMS)
  def testQR(self, rows, cols, dtype):
    # TODO(b/111317468): implement full_matrices=False, test other types.
    for full_matrices in [True]:
      # Only tests the (3, 2) case for small numbers of rows/columns.
      for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
        self._test(dtype, batch_dims + (rows, cols), full_matrices)

  def testLarge2000x2000(self):
    self._test(np.float32, (2000, 2000), full_matrices=True)


if __name__ == "__main__":
  test.main()
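Note: _test above validates XLA's QR against numpy, picking the np.linalg.qr mode from full_matrices. For an m x n input, the two modes differ only in the shape of Q. A small illustration (assumes only numpy, not part of the commit):

import numpy as np

a = np.random.uniform(-1.0, 1.0, size=(5, 3)).astype(np.float32)
q_full, r_full = np.linalg.qr(a, mode="complete")  # Q: (5, 5), R: (5, 3)
q_thin, r_thin = np.linalg.qr(a, mode="reduced")   # Q: (5, 3), R: (3, 3)
assert np.allclose(a, np.matmul(q_full, r_full), atol=1e-5)
assert np.allclose(a, np.matmul(q_thin, r_thin), atol=1e-5)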
@@ -22,7 +22,7 @@ import math

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -31,7 +31,7 @@ from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import googletest


-class RandomOpsTest(XLATestCase):
+class RandomOpsTest(xla_test.XLATestCase):
  """Test cases for random-number generating operators."""

  def _random_types(self):
@@ -140,10 +140,10 @@ class RandomOpsTest(XLATestCase):
  def testShuffle1d(self):
    with self.test_session() as sess:
      with self.test_scope():
-        x = math_ops.range(20)
+        x = math_ops.range(1 << 16)
        shuffle = random_ops.random_shuffle(x)
        result = sess.run(shuffle)
-        expected = range(20)
+        expected = range(1 << 16)
        # Compare sets to avoid randomness behavior changes but make sure still
        # have all the values.
        self.assertAllEqual(set(result), set(expected))
@@ -22,7 +22,7 @@ import functools
import itertools
import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest


-class ReduceOpsTest(XLATestCase):
+class ReduceOpsTest(xla_test.XLATestCase):

  def _testReduction(self,
                     tf_reduce_fn,
@@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase):
    self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)


-class ReduceOpPrecisionTest(XLATestCase):
+class ReduceOpPrecisionTest(xla_test.XLATestCase):

  def _testReduceSum(self,
                     expected_result,
@@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest


-class ReduceWindowTest(XLATestCase):
+class ReduceWindowTest(xla_test.XLATestCase):
  """Test cases for xla.reduce_window."""

  def _reduce_window(self, operand, init, reducer, **kwargs):
@@ -21,14 +21,14 @@ from __future__ import print_function
import itertools
import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest


-class ReverseOpsTest(XLATestCase):
+class ReverseOpsTest(xla_test.XLATestCase):

  def testReverseOneDim(self):
    shape = (7, 5, 9, 11)
@@ -20,13 +20,13 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


-class ReverseSequenceTest(XLATestCase):
+class ReverseSequenceTest(xla_test.XLATestCase):

  def _testReverseSequence(self,
                           x,
@@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -28,33 +28,104 @@ from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop


-class RmspropTest(XLATestCase):
+class RmspropTest(xla_test.XLATestCase):
+
+  def _rmsprop_update_numpy(self,
+                            var,
+                            g,
+                            mg,
+                            rms,
+                            mom,
+                            lr,
+                            decay=0.9,
+                            momentum=0.0,
+                            epsilon=1e-10,
+                            centered=False):
+    rms_t = rms * decay + (1 - decay) * g * g
+    denom_t = rms_t + epsilon
+    if centered:
+      mg_t = mg * decay + (1 - decay) * g
+      denom_t -= mg_t * mg_t
+    else:
+      mg_t = mg
+    mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+    var_t = var - mom_t
+    return var_t, mg_t, rms_t, mom_t

  def testBasic(self):
    for dtype in self.float_types:
-      with self.test_session(), self.test_scope():
-        var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
-        var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
-        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
-        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
-        rms_opt = rmsprop.RMSPropOptimizer(3.0)
-        rms_update = rms_opt.apply_gradients(
-            zip([grads0, grads1], [var0, var1]))
-        variables.global_variables_initializer().run()
+      for centered in [False, True]:
+        with self.test_session(), self.test_scope():
+          # Initialize variables for numpy implementation.
+          var0_np = np.array([1.0, 2.0], dtype=dtype)
+          grads0_np = np.array([0.1, 0.1], dtype=dtype)
+          var1_np = np.array([3.0, 4.0], dtype=dtype)
+          grads1_np = np.array([0.01, 0.01], dtype=dtype)
+          mg0_np = np.array([0.0, 0.0], dtype=dtype)
+          mg1_np = np.array([0.0, 0.0], dtype=dtype)
+          rms0_np = np.array([1.0, 1.0], dtype=dtype)
+          rms1_np = np.array([1.0, 1.0], dtype=dtype)
+          mom0_np = np.array([0.0, 0.0], dtype=dtype)
+          mom1_np = np.array([0.0, 0.0], dtype=dtype)

-        # Fetch params to validate initial values
-        self.assertAllClose([1.0, 2.0], var0.eval())
-        self.assertAllClose([3.0, 4.0], var1.eval())
+          var0 = resource_variable_ops.ResourceVariable(var0_np)
+          var1 = resource_variable_ops.ResourceVariable(var1_np)
+          grads0 = constant_op.constant(grads0_np)
+          grads1 = constant_op.constant(grads1_np)
+          learning_rate = 3.0
+          rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered)
+          rms_update = rms_opt.apply_gradients(
+              zip([grads0, grads1], [var0, var1]))
+          variables.global_variables_initializer().run()

-        # Run 3 steps of RMSProp
-        for _ in range(3):
-          rms_update.run()
+          mg0 = rms_opt.get_slot(var0, "mg")
+          self.assertEqual(mg0 is not None, centered)
+          mg1 = rms_opt.get_slot(var1, "mg")
+          self.assertEqual(mg1 is not None, centered)
+          rms0 = rms_opt.get_slot(var0, "rms")
+          self.assertTrue(rms0 is not None)
+          rms1 = rms_opt.get_slot(var1, "rms")
+          self.assertTrue(rms1 is not None)
+          mom0 = rms_opt.get_slot(var0, "momentum")
+          self.assertTrue(mom0 is not None)
+          mom1 = rms_opt.get_slot(var1, "momentum")
+          self.assertTrue(mom1 is not None)

-        # Validate updated params
-        self.assertAllCloseAccordingToType(
-            np.array([2.91705132e-04, 1.00029182e+00]), var0.eval())
-        self.assertAllCloseAccordingToType(
-            np.array([2.89990854, 3.89990854]), var1.eval())
+          # Fetch params to validate initial values
+          self.assertAllClose([1.0, 2.0], var0.eval())
+          self.assertAllClose([3.0, 4.0], var1.eval())
+
+          # Run 3 steps of RMSProp
+          for _ in range(3):
+            rms_update.run()
+
+            var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+                var0_np,
+                grads0_np,
+                mg0_np,
+                rms0_np,
+                mom0_np,
+                learning_rate,
+                centered=centered)
+            var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+                var1_np,
+                grads1_np,
+                mg1_np,
+                rms1_np,
+                mom1_np,
+                learning_rate,
+                centered=centered)
+
+            # Validate updated params
+            if centered:
+              self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+              self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+            self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+            self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+            self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+            self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+            self.assertAllCloseAccordingToType(var0_np, var0.eval())
+            self.assertAllCloseAccordingToType(var1_np, var1.eval())


if __name__ == "__main__":
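Note: _rmsprop_update_numpy in the new test code above encodes the RMSProp step being checked. With decay rho, momentum mu, learning rate eta, and the mean-gradient term only in the centered variant:

\begin{aligned}
r_t &= \rho\, r_{t-1} + (1-\rho)\, g_t^2, \\
\bar{g}_t &= \rho\, \bar{g}_{t-1} + (1-\rho)\, g_t \quad \text{(centered only)}, \\
d_t &= r_t + \epsilon - \bar{g}_t^2 \ \text{(centered)} \quad \text{or} \quad r_t + \epsilon, \\
v_t &= \mu\, v_{t-1} + \eta\, g_t / \sqrt{d_t}, \\
\theta_t &= \theta_{t-1} - v_t.
\end{aligned}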
@@ -20,7 +20,7 @@ from __future__ import print_function

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse):
  return x


-class CumsumTest(XLATestCase):
+class CumsumTest(xla_test.XLATestCase):

  valid_dtypes = [np.float32]

@@ -147,7 +147,7 @@ class CumsumTest(XLATestCase):
      math_ops.cumsum(input_tensor, [0]).eval()


-class CumprodTest(XLATestCase):
+class CumprodTest(xla_test.XLATestCase):

  valid_dtypes = [np.float32]
@@ -22,7 +22,7 @@ import functools

import numpy as np

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape):
  return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)


-class ScatterNdTest(XLATestCase):
+class ScatterNdTest(xla_test.XLATestCase):

  def _VariableRankTest(self,
                        np_scatter,
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest


-class SliceTest(XLATestCase):
+class SliceTest(xla_test.XLATestCase):

  def test1D(self):
    for dtype in self.numeric_types:
@@ -110,7 +110,7 @@ class SliceTest(XLATestCase):
    self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result)


-class StridedSliceTest(XLATestCase):
+class StridedSliceTest(xla_test.XLATestCase):

  def test1D(self):
    for dtype in self.numeric_types:
@@ -64,20 +64,61 @@ class XlaSortOpTest(xla_test.XLATestCase):
    if self.device in ["XLA_CPU", "XLA_GPU"]:
      return

-    # Only bfloat16 is implemented.
-    bfloat16 = dtypes.bfloat16.as_numpy_dtype
-    if bfloat16 in self.numeric_types:
-      for x in [np.arange(20)]:
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+    for dtype in supported_types.intersection(self.numeric_types):
+      # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+      # after conversion to bfloat16, so the possible resulting index array is
+      # no longer unique.
+      if dtype == dtypes.bfloat16.as_numpy_dtype:
+        array_size = 20
+        k_options = [0, 1, 2, 10, 20]
+      else:
+        array_size = 200 * 1000
+        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      for x in [np.arange(array_size)]:
        np.random.shuffle(x)
-        for k in [0, 1, 2, 10, 20]:
+        for k in k_options:
          indices = x.argsort()[::-1][:k]

          def topk(v, k=k):
            return nn_ops.top_k(v, k=k, sorted=True)

          self._assertOpOutputMatchesExpected(
-              topk, [x.astype(bfloat16)],
-              expected=[x[indices].astype(bfloat16), indices])
+              topk, [x.astype(dtype)],
+              expected=[x[indices].astype(dtype), indices])
+
+  def testTopK2D(self):
+    # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+    if self.device in ["XLA_CPU", "XLA_GPU"]:
+      return
+
+    supported_types = set(
+        [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+    for dtype in supported_types.intersection(self.numeric_types):
+      # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+      # after conversion to bfloat16, so the possible resulting index array is
+      # no longer unique.
+      if dtype == dtypes.bfloat16.as_numpy_dtype:
+        array_size = 10
+        k_options = [0, 1, 2, 10]
+      else:
+        array_size = 200 * 1000
+        k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+      batch = 16
+      for x in [np.arange(batch * array_size)]:
+        np.random.shuffle(x)
+        x = np.reshape(x, [batch, array_size])
+        for k in k_options:
+          indices = x.argsort(axis=1)[::, -1:-k - 1:-1]
+          expected = np.sort(x, axis=1)[::, -1:-k - 1:-1]
+
+          def topk(v, k=k):
+            return nn_ops.top_k(v, k=k, sorted=True)
+
+          self._assertOpOutputMatchesExpected(
+              topk, [x.astype(dtype)],
+              expected=[expected.astype(dtype), indices])

  def testTopKZeros(self):
    """Tests that positive and negative zeros sort correctly."""
@@ -99,7 +140,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
        {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
    self.assertAllEqual(
        np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
-    self.assertEqual(list([3, 0, 1, 2]), list(results[1]))
+    self.assertEqual(list([3, 0, 2, 6]), list(results[1]))

  def testTopKInfinities(self):
    """Tests that positive and negative infinity sort correctly."""
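Note: the argsort slicing added in testTopK2D reads the top k entries per row in descending order: argsort ascending along axis 1, then take the last k columns reversed. A small numpy illustration (not part of the commit):

import numpy as np

x = np.array([[5, 1, 9, 3],
              [2, 8, 4, 6]])
k = 2
indices = x.argsort(axis=1)[::, -1:-k - 1:-1]  # [[2, 0], [1, 3]]
values = np.sort(x, axis=1)[::, -1:-k - 1:-1]  # [[9, 5], [8, 6]]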
Some files were not shown because too many files have changed in this diff.