Merge branch 'master' into ecosystem-jars

This commit is contained in:
Soila Kavulya 2018-06-12 16:09:31 -07:00
commit 34bf5f4b8e
1197 changed files with 53078 additions and 24233 deletions
CONTRIBUTING.mdREADME.mdRELEASE.mdSECURITY.mdWORKSPACEconfigure.py
tensorflow
BUILD__init__.pyapi_template.__init__.py
c
cc
framework
gradients
compiler
aot/tests
jit
tests
tf2xla
xla

View File

@ -90,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g.,
Changes to TensorFlow C++ code should conform to
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do:
Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do:
```bash
apt-get install -y clang-tidy

View File

@ -56,6 +56,7 @@ $ python
42
>>> sess.close()
```
Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/).
## Contribution guidelines

View File

@ -1,3 +1,62 @@
# Release 1.9.0
## Major Features And Improvements
* Update tf.keras to the Keras 2.1.6 API.
* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`.
* Adding support of core feature columns and losses to gradient boosted trees estimators.
* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details.
* Layered variable names have changed in the following conditions:
* Using `tf.keras.layers` with custom variable scopes.
* Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details
## Breaking Chances
* If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...).
## Bug Fixes and Other Changes
* `tf.data`:
* The `DatasetBase::DebugString()` method is now `const`.
* Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets.
* Eager Execution:
* `tf.keras`:
* Move Keras code out of _impl folder and remove API files.
* `tf.keras.Model.save_weights` now saves in TensorFlow format by default.
* Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods.
* Accelerated Linear Algebra (XLA):
* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB).
* `tf.contrib`:
* Add `tf.contrib.data.choose_from_datasets()`.
* `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`.
* `tf.contrib.framework.zero_initializer` supports ResourceVariable.
* Adding "constrained_optimization" to tensorflow/contrib.
* Other:
* Add GCS Configuration Ops.
* Changing signature of `MakeIterator` to enable propagating error status.
* KL divergence for two Dirichlet distributions.
* More consistent GcsFileSystem behavior for certain reads past EOF.
* Update benchmark for tf.scan to match ranges across eager and graph modes.
* Fixed bug in `tf.reduce_prod gradient` for complex dtypes.
* Add optional `args` argument to `Dataset.from_generator()`.
* Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)").
* Benchmark for tf.scan in graph and eager modes.
* Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D.
* Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch.
* Support indicator column in boosted trees.
* Prevent `tf.gradients()` from backpropagating through integer tensors.
* LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`.
* Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary.
* Added `tf.train.Checkpoint` for reading/writing object-based checkpoints.
* `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed.
* Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product.
* Allow LinearOperator to broadcast.
* SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other.
## Thanks to our Contributors
This release contains contributions from many people at Google, as well as:
Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang
# Release 1.8.0
## Major Features And Improvements
@ -404,14 +463,6 @@ answered questions, and were part of inspiring discussions.
# Release 1.4.0
## Major Features And Improvements
* `tf.keras` is now part of the core TensorFlow API.
* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of
the core TensorFlow API.
* The API is now subject to backwards compatibility guarantees.
# Release 1.4.0
## Major Features And Improvements
* `tf.keras` is now part of the core TensorFlow API.
* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of

View File

@ -242,12 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
-----END PGP PUBLIC KEY BLOCK-----
```
### Known vulnerabilities
| Type | Versions affected | Reported by | Additional Information |
|--------------------|:-----------------:|-----------------------|-----------------------------|
| TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-003.md) |
| GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-002.md) |
| BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-001.md) |
| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
### Known Vulnerabilities
For a list of known vulnerabilities and security advisories for TensorFlow,
(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md)[click here].

View File

@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0")
load("//tensorflow:workspace.bzl", "tf_workspace")
# Uncomment and update the paths in these entries to build the Android demo.
#android_sdk_repository(
# name = "androidsdk",
# api_level = 23,
# # Ensure that you have the build_tools_version below installed in the
# # SDK manager as it updates periodically.
# build_tools_version = "26.0.1",
# # Replace with path to Android SDK on your system
# path = "<PATH_TO_SDK>",
#)
#
#android_ndk_repository(
# name="androidndk",
# path="<PATH_TO_NDK>",
# # This needs to be 14 or higher to compile TensorFlow.
# # Please specify API level to >= 21 to build for 64-bit
# # archtectures or the Android NDK will automatically select biggest
# # API level that it supports without notice.
# # Note that the NDK version is not the API level.
# api_level=14)
load("//third_party/android:android_configure.bzl", "android_configure")
android_configure(name="local_config_android")
load("@local_config_android//:android.bzl", "android_workspace")
android_workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()

View File

@ -670,8 +670,9 @@ def create_android_ndk_rule(environ_cp):
error_msg=('The path %s or its child file "source.properties" '
'does not exist.')
)
write_android_ndk_workspace_rule(android_ndk_home_path)
write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
check_ndk_level(android_ndk_home_path))
def create_android_sdk_rule(environ_cp):
@ -733,41 +734,12 @@ def create_android_sdk_rule(environ_cp):
error_msg=('The selected SDK does not have build-tools version %s '
'available.'))
write_android_sdk_workspace_rule(android_sdk_home_path,
android_build_tools_version,
android_api_level)
def write_android_sdk_workspace_rule(android_sdk_home_path,
android_build_tools_version,
android_api_level):
print('Writing android_sdk_workspace rule.\n')
with open(_TF_WORKSPACE, 'a') as f:
f.write("""
android_sdk_repository(
name="androidsdk",
api_level=%s,
path="%s",
build_tools_version="%s")\n
""" % (android_api_level, android_sdk_home_path, android_build_tools_version))
def write_android_ndk_workspace_rule(android_ndk_home_path):
print('Writing android_ndk_workspace rule.')
ndk_api_level = check_ndk_level(android_ndk_home_path)
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print('WARNING: The API level of the NDK in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' % (android_ndk_home_path, ndk_api_level,
_SUPPORTED_ANDROID_NDK_VERSIONS))
with open(_TF_WORKSPACE, 'a') as f:
f.write("""
android_ndk_repository(
name="androidndk",
path="%s",
api_level=%s)\n
""" % (android_ndk_home_path, ndk_api_level))
write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION',
android_build_tools_version)
write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL',
android_api_level)
write_action_env_to_bazelrc('ANDROID_SDK_HOME',
android_sdk_home_path)
def check_ndk_level(android_ndk_home_path):
@ -780,18 +752,16 @@ def check_ndk_level(android_ndk_home_path):
revision = re.search(r'Pkg.Revision = (\d+)', filedata)
if revision:
return revision.group(1)
return None
def workspace_has_any_android_rule():
"""Check the WORKSPACE for existing android_*_repository rules."""
with open(_TF_WORKSPACE, 'r') as f:
workspace = f.read()
has_any_rule = re.search(r'^android_[ns]dk_repository',
workspace,
re.MULTILINE)
return has_any_rule
ndk_api_level = revision.group(1)
else:
raise Exception('Unable to parse NDK revision.')
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print('WARNING: The API level of the NDK in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' % (android_ndk_home_path, ndk_api_level,
_SUPPORTED_ANDROID_NDK_VERSIONS))
return ndk_api_level
def set_gcc_host_compiler_path(environ_cp):
@ -1223,7 +1193,7 @@ def set_tf_cuda_compute_capabilities(environ_cp):
# Check whether all capabilities from the input is valid
all_valid = True
# Remove all whitespace characters before splitting the string
# that users may insert by accident, as this will result in error
# that users may insert by accident, as this will result in error
tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split())
for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability)
@ -1427,6 +1397,10 @@ def set_grpc_build_flags():
write_to_bazelrc('build --define grpc_no_ares=true')
def set_build_strip_flag():
write_to_bazelrc('build --strip=always')
def set_windows_build_flags():
if is_windows():
# The non-monolithic build is not supported yet
@ -1549,23 +1523,18 @@ def main():
set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
set_build_strip_flag()
set_windows_build_flags()
if workspace_has_any_android_rule():
print('The WORKSPACE file has at least one of ["android_sdk_repository", '
'"android_ndk_repository"] already set. Will not ask to help '
'configure the WORKSPACE. Please delete the existing rules to '
'activate the helper.\n')
else:
if get_var(
environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
False,
('Would you like to interactively configure ./WORKSPACE for '
'Android builds?'),
'Searching for NDK and SDK installations.',
'Not configuring the WORKSPACE for Android builds.'):
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
if get_var(
environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
False,
('Would you like to interactively configure ./WORKSPACE for '
'Android builds?'),
'Searching for NDK and SDK installations.',
'Not configuring the WORKSPACE for Android builds.'):
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
print('Preconfigured Bazel build configs. You can use any of the below by '
'adding "--config=<>" to your build command. See tools/bazel.rc for '

View File

@ -19,6 +19,10 @@ load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_additional_binary_deps",
)
load(
"//tensorflow/tools/api/generator:api_gen.bzl",
"gen_api_init_files", # @unused
)
# Config setting for determining if we are building for Android.
config_setting(
@ -471,7 +475,7 @@ tf_cc_shared_object(
# excludes all but a subset of function names.
# On MacOS, the linker does not support version_script, but has an
# an "-exported_symbols_list" command. -z defs disallows undefined
# symbols in object files and -s strips the output.
# symbols in object files.
tf_cc_shared_object(
name = "libtensorflow.so",
@ -485,7 +489,6 @@ tf_cc_shared_object(
"//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-s",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
"$(location //tensorflow/c:version_script.lds)",
],
@ -511,7 +514,6 @@ tf_cc_shared_object(
"//tensorflow:windows_msvc": [],
"//conditions:default": [
"-z defs",
"-s",
"-Wl,--version-script", # This line must be directly followed by the version_script.lds file
"$(location //tensorflow:tf_version_script.lds)",
],
@ -536,13 +538,19 @@ exports_files(
],
)
gen_api_init_files(
name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"],
root_init_template = "api_template.__init__.py",
)
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
srcs = [
":tensorflow_python_api_gen",
"//tensorflow/python/estimator/api:estimator_python_api_gen",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/python",
"//tensorflow/tools/api/generator:python_api",
],
deps = ["//tensorflow/python"],
)

View File

@ -22,9 +22,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# pylint: disable=wildcard-import
from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')

View File

@ -0,0 +1,58 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# API IMPORTS PLACEHOLDER
try:
import os # pylint: disable=g-import-not-at-top
# Add `estimator` attribute to allow access to estimator APIs via
# "tf.estimator..."
from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
# Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
# style imports.
from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
__path__ += [os.path.dirname(estimator_api.__file__)]
del estimator_api
del os
except (ImportError, AttributeError):
print('tf.estimator package not installed.')
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
del absolute_import
del division
del print_function
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
del python
del core
# pylint: enable=undefined-variable

View File

@ -631,7 +631,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
"Failed to allocate memory to serialize message of type '",
in.GetTypeName(), "' and size ", proto_size);
}
in.SerializeToArray(buf, proto_size);
// SerializeToArray takes size as an int.
// This next 'if' is a workaround till we update to depend on a version
// of protocol buffers that includes
// https://github.com/google/protobuf/pull/4739
if (proto_size > std::numeric_limits<int>::max()) {
return InvalidArgument("Cannot serialize protocol buffer of type ",
in.GetTypeName(), " as the serialized size (",
proto_size,
"bytes) would be larger than the limit (",
std::numeric_limits<int>::max(), " bytes)");
}
if (!in.SerializeToArray(buf, proto_size)) {
return InvalidArgument("Unable to serialize ", in.GetTypeName(),
" protocol buffer, perhaps the serialized size (",
proto_size, " bytes) is too large?");
}
out->data = buf;
out->length = proto_size;
out->data_deallocator = [](void* data, size_t length) {

View File

@ -142,8 +142,10 @@ void TestRemoteExecute(bool async) {
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@ -205,6 +207,83 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
void TestRemoteExecuteSilentCopies(bool async) {
tensorflow::ServerDef server_def = GetServerDef(2);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
ASSERT_TRUE(
tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
.ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
const char remote_device_name[] =
"/job:localhost/replica:0/task:1/device:CPU:0";
// Handles are on task0, but op is on remote (task1).
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0);
TFE_OpSetDevice(matmul, remote_device_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteOp(matmul);
TFE_ContextAsyncWait(ctx, status);
TFE_DeleteContext(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
// TODO(nareshmodi): Figure out how to correctly shut the server down.
worker_server.release();
}
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(true);
}
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));

View File

@ -15,10 +15,12 @@
# ==============================================================================
TF_PREFIX='/usr/local'
LIBDIR='lib'
usage() {
echo "Usage: $0 OPTIONS"
echo -e "-p, --prefix\tset installation prefix (default: /usr/local)"
echo -e "-l, --libdir\tset lib directory (default: lib)"
echo -e "-v, --version\tset TensorFlow version"
echo -e "-h, --help\tdisplay this message"
}
@ -26,7 +28,7 @@ usage() {
[ $# == 0 ] && usage && exit 0
# read the options
ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@")
ARGS=$(getopt -o p:l:v:h --long prefix:,libdir:,version:,help -n $0 -- "$@")
eval set -- "$ARGS"
# extract options and their arguments into variables.
@ -38,6 +40,11 @@ while true ; do
"") shift 2 ;;
*) TF_PREFIX=$2 ; shift 2 ;;
esac ;;
-l|--libdir)
case "$2" in
"") shift 2 ;;
*) LIBDIR=$2 ; shift 2 ;;
esac ;;
-v|--version)
case "$2" in
"") shift 2 ;;
@ -55,7 +62,7 @@ echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX"
cat << EOF > tensorflow.pc
prefix=${TF_PREFIX}
exec_prefix=\${prefix}
libdir=\${exec_prefix}/lib
libdir=\${exec_prefix}/${LIBDIR}
includedir=\${prefix}/include
Name: TensorFlow

View File

@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) {
return "<Unknown AttrValue type>"; // Prevent missing return warning
}
bool IsEmptyList(const AttrValue::ListValue& list) {
return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 &&
list.b_size() == 0 && list.type_size() == 0 &&
list.shape_size() == 0 && list.tensor_size() == 0;
}
string ToCamelCase(const string& str) {
string result;
const char joiner = '_';
@ -297,9 +303,9 @@ string ToCamelCase(const string& str) {
// indicate whether to treat the type as const when accepting the C++ type as an
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
StringPieceHasher>
attr_type_map{
static const auto* attr_type_map =
new std::unordered_map<StringPiece, std::pair<const char*, bool>,
StringPieceHasher>{
{"string", {"StringPiece", false}},
{"list(string)", {"gtl::ArraySlice<string>", true}},
{"int", {"int64", false}},
@ -317,14 +323,34 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
{"func", {"NameAttrList", true}},
};
auto entry = attr_type_map.find(attr_type);
if (entry == attr_type_map.end()) {
auto entry = attr_type_map->find(attr_type);
if (entry == attr_type_map->end()) {
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
return {"", false};
}
return entry->second;
}
const char* ListElementTypeName(StringPiece attr_type) {
static const auto* attr_list_type_map =
new std::unordered_map<StringPiece, const char*, StringPieceHasher>{
{"list(string)", "string"},
{"list(int)", "int"},
{"list(float)", "float"},
{"list(bool)", "bool"},
{"list(type)", "DataType"},
{"list(shape)", "PartialTensorShape"},
{"list(tensor)", "TensorProto"},
};
auto entry = attr_list_type_map->find(attr_type);
if (entry == attr_list_type_map->end()) {
LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type;
return "";
}
return entry->second;
}
bool IsCPPKeyword(StringPiece name) {
static const std::unordered_set<StringPiece, StringPieceHasher>
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
string defaults_static_storage;
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const {
"_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
strings::StrAppend(
&struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
"_ = ",
PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
";\n");
string field_initiliazer;
auto& default_value = api_def_attr.default_value();
if (default_value.value_case() == AttrValue::kList &&
!IsEmptyList(default_value.list())) {
// Non-empty lists need static storage for their defaults. Define a
// function with static local variable that stores the array.
strings::StrAppend(&defaults_static_storage, " static ",
attr_type_name, " Default_", api_def_attr.rename_to(),
"() {\n");
strings::StrAppend(
&defaults_static_storage, " static const ",
ListElementTypeName(attr.type()), " kStorage[] = ",
PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
";\n");
strings::StrAppend(&defaults_static_storage, " return ",
attr_type_name, "(kStorage);\n }\n");
// Set the field_initializer to call the defined function.
strings::StrAppend(&field_initiliazer, "Default_",
api_def_attr.rename_to(), "()");
} else {
field_initiliazer =
PrintAttrValue(graph_op_def.name(), api_def_attr.default_value());
}
strings::StrAppend(&struct_fields, " ", attr_type_name, " ",
api_def_attr.rename_to(), "_ = ", field_initiliazer,
";\n");
}
if (struct_fields.empty()) {
@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const {
string struct_decl = MakeComment(attrs_comment, " ");
strings::StrAppend(&struct_decl, " struct Attrs {\n");
strings::StrAppend(&struct_decl, setters, struct_fields);
if (!defaults_static_storage.empty()) {
strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage);
}
strings::StrAppend(&struct_decl, " };\n");
return struct_decl;

View File

@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");
REGISTER_NO_GRADIENT_OP("Floor");
// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.

View File

@ -7,6 +7,10 @@ package(
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
# We disable some tfcompile tests in the open source build with the
# "manual" tag to avoid making our OSS users build LLVM twice
# (once for host and once for target).
test_suite(
name = "all_tests",
tags = ["manual"],

View File

@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@ -180,6 +181,7 @@ cc_library(
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:variable_ops",
],
)
@ -312,6 +314,7 @@ cc_library(
":common",
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
@ -332,6 +335,19 @@ cc_library(
],
)
cc_library(
name = "xla_cluster_util",
srcs = ["xla_cluster_util.cc"],
hdrs = ["xla_cluster_util.h"],
deps = [
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
],
)
cc_library(
name = "union_find",
hdrs = ["union_find.h"],
@ -408,6 +424,38 @@ tf_cc_test(
],
)
cc_library(
name = "xla_fusion_optimizer",
srcs = ["xla_fusion_optimizer.cc"],
hdrs = ["xla_fusion_optimizer.h"],
visibility = ["//visibility:public"],
deps = [
":common",
":union_find",
":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
],
)
tf_cuda_cc_test(
name = "xla_fusion_optimizer_test",
srcs = ["xla_fusion_optimizer_test.cc"],
deps = [
":common",
":xla_cluster_util",
":xla_fusion_optimizer",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/utils:grappler_test",
],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",

View File

@ -1050,7 +1050,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
.WithAttr("_outside", "O1"));
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT, DT_FLOAT}, shape2.opts());
Node* h = Binary(ops::NodeOut(recv2, 0), e,
Node* h = Binary(ops::NodeOut(recv2, 1), e,
shape2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
@ -1075,7 +1075,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"D:o:0", "F:o:0"},
{"F:o:0", "D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors",
@ -1123,13 +1123,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT, DT_FLOAT}, b2.opts());
Node* g = Binary(e, ops::NodeOut(recv2, 1),
Node* g = Binary(e, ops::NodeOut(recv2, 0),
b2.opts()
.WithName("G")
.WithControlInputs({recv2, e})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* h = Binary(ops::NodeOut(recv2, 0), e,
Node* h = Binary(ops::NodeOut(recv2, 1), e,
b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@ -41,9 +42,6 @@ limitations under the License.
namespace tensorflow {
const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
namespace {
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
@ -60,6 +58,14 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
return false;
}
}
// XLA does not offer guaranteed aliasing between the input and output of the
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
return false;
}
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
@ -165,16 +171,6 @@ bool IsCompilableCall(const NodeDef& call_def,
return true;
}
// Returns the DeviceType corresponding to 'device'.
Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::Internal("Malformed assigned device '", device, "'");
}
*device_type = DeviceType(parsed.type);
return Status::OK();
}
// Tests whether `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node) {
return std::find(node.input_types().begin(), node.input_types().end(),
@ -183,18 +179,11 @@ bool HasResourceInputOrOutput(const Node& node) {
DT_RESOURCE) != node.output_types().end();
}
struct NodeCompare {
bool operator()(const Node* a, const Node* b) const {
return a->id() < b->id();
}
};
using OrderedNodeSet = std::set<Node*, NodeCompare>;
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
// TODO(hpucha): Consider a black list instead of a white list as
// implemented below.
// TODO(hpucha): Remove this code since this functionality is subsumed by
// Grappler XlaFusionOptimizer.
bool IsXlaFusable(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
@ -364,7 +353,7 @@ Status FindCompilationCandidates(
for (Node* node : graph.op_nodes()) {
sorted_nodes.push_back(node);
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare());
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
for (Node* node : sorted_nodes) {
VLOG(2) << "Fuel: " << fuel;
@ -379,9 +368,13 @@ Status FindCompilationCandidates(
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
DeviceToDeviceType(node->assigned_device_name(), &device_type));
if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
VLOG(2) << "Compilation rejected node: not compilable " << node->name()
<< ": " << node->type_string();
continue;
}
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(
@ -430,46 +423,6 @@ struct Cluster {
int representative = -1;
};
// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
int dst) {
int32 max_path_size = graph.num_node_ids() + 1;
std::vector<int32> path(max_path_size);
int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data());
if (path_size == 0) {
return "";
}
auto node_name = [&cycles, &graph](int node_id) {
if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
return string("(null)");
}
auto* node = graph.FindNodeId(node_id);
if (node == nullptr) {
return string("(null)");
}
return node->name();
};
string description;
strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
node_name(dst), " would create a cycle.\n");
path.resize(path_size);
for (int32 node_id : path) {
string ascii_art;
if (node_id == dst) {
ascii_art = "+-> ";
} else if (node_id != src) {
ascii_art = "| ";
} else {
ascii_art = "+-- ";
}
strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
}
return description;
}
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@ -575,84 +528,13 @@ Status MarkForCompilationPass::RunImpl(
: Env::Default(),
is_compilable_fn, &compilation_candidates));
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
return Status::OK();
}
GraphCycles cycles;
for (int i = 0; i < graph->num_node_ids(); ++i) {
// We rely on the node IDs in the cycle detection graph being consecutive
// integers starting from 0.
CHECK_EQ(i, cycles.NewNode());
}
// Compute the loop structure of the graph.
std::vector<ControlFlowInfo> control_flow_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
// The clustering code must avoid adding cycles to the graph to prevent
// deadlock. However, the graph may contain loops, which would trigger the
// cycle detection code. To handle loops, we alter the structure of the cycle
// detection graph, disconnecting each loop from the enclosing graph.
// Specifically, we:
// * add a new "frame" node for each loop.
// * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
// to/from the corresponding frame node. In essence, we collapse the loop
// into a single node for the purpose of cycle detection in the enclosing
// graph.
// * the body of the loop should now be disconnected from the rest of the
// graph; we make it acyclic by breaking loop backedges (edges outgoing from
// "NextIteration" nodes.
// Map from frame name strings to node IDs in the cycle detection graph.
std::unordered_map<string, int> frame_nodes;
// Get the cycle graph node ID for frame 'frame_name', or add one if none
// exists.
auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
if (frame_id < 0) {
// The emplace succeeded; we have not allocated a frame node yet.
frame_id = cycles.NewNode();
}
return frame_id;
};
for (Edge const* edge : graph->edges()) {
if (edge->dst()->IsEnter()) {
// Lift edges to an "Enter" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->dst()->id()].frame_name;
int dst = GetOrAddFrameNodeId(frame_name);
if (!cycles.InsertEdge(edge->src()->id(), dst)) {
return errors::Internal(
"Cycle detected when adding enter->frame edge: ",
DescribeCycle(cycles, *graph, edge->src()->id(), dst));
}
continue;
}
if (edge->src()->IsExit()) {
// Lift edges from an "Exit" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->src()->id()].frame_name;
int src = GetOrAddFrameNodeId(frame_name);
if (!cycles.InsertEdge(src, edge->dst()->id())) {
return errors::Internal(
"Cycle detected when adding frame->exit edge: ",
DescribeCycle(cycles, *graph, src, edge->dst()->id()));
}
// Drop the original edge.
continue;
}
if (edge->src()->IsNextIteration()) {
// Break loop back-edges.
continue;
}
if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
// This should never happen. All cycles in the graph should contain
// a control flow operator.
return errors::Internal(
"Found cycle in graph without control flow operator during XLA "
"compilation: ",
DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
}
}
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
// Each compilation candidate belongs to a cluster. The cluster's
// representative
@ -670,6 +552,9 @@ Status MarkForCompilationPass::RunImpl(
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
//
// TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
// example, from the Grappler fusion pass).
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
@ -778,7 +663,7 @@ Status MarkForCompilationPass::RunImpl(
// compilation.
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
DeviceToDeviceType(n->assigned_device_name(), &device_type));
const XlaOpRegistry::DeviceRegistration* registration;
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);

View File

@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) {
}
}
TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
Scope root = Scope::NewRootScope().ExitOnError();
Output variable = ops::Variable(root.WithOpName("variable"),
PartialTensorShape{}, DT_FLOAT);
Output read = ops::Identity(root.WithOpName("read"), variable);
Output neg = ops::Negate(root.WithOpName("negate"), read);
Output add = ops::Add(root.WithOpName("add"), neg, neg);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
ASSERT_FALSE(clusters.empty());
string cluster_name = clusters.begin()->second;
std::unordered_map<string, string> expected_clusters(
{{"negate", cluster_name}, {"add", cluster_name}});
EXPECT_EQ(clusters, expected_clusters);
}
TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
Scope root = Scope::NewRootScope().ExitOnError();
Output variable = ops::Variable(root.WithOpName("variable"),
PartialTensorShape{}, DT_FLOAT);
Output read = ops::Identity(root.WithOpName("read"), variable);
Output neg = ops::Negate(root.WithOpName("negate"), read);
Output identity = ops::Negate(root.WithOpName("identity"), neg);
Output add = ops::Add(root.WithOpName("add"), identity, neg);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
ASSERT_FALSE(clusters.empty());
string cluster_name = clusters.begin()->second;
std::unordered_map<string, string> expected_clusters(
{{"negate", cluster_name},
{"identity", cluster_name},
{"add", cluster_name}});
EXPECT_EQ(clusters, expected_clusters);
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,183 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include <unordered_map>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
const char* const kXlaClusterAttr = "_XlaCluster";
const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
namespace {
// Returns a string describing how an edge from src to dst would
// create a cycle.
string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
int dst) {
int32 max_path_size = graph.num_node_ids() + 1;
std::vector<int32> path(max_path_size);
int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data());
if (path_size == 0) {
return "";
}
auto node_name = [cycles, &graph](int node_id) {
if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
return string("(null)");
}
auto* node = graph.FindNodeId(node_id);
if (node == nullptr) {
return string("(null)");
}
return node->name();
};
string description;
strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
node_name(dst), " would create a cycle.\n");
path.resize(path_size);
for (int32 node_id : path) {
string ascii_art;
if (node_id == dst) {
ascii_art = "+-> ";
} else if (node_id != src) {
ascii_art = "| ";
} else {
ascii_art = "+-- ";
}
strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
}
return description;
}
bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }
} // namespace
Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
return errors::Internal("Malformed assigned device '", device, "'");
}
*device_type = DeviceType(parsed.type);
return Status::OK();
}
bool HasForwardedRefInput(const Node& node) {
if (AlwaysForwardsRefInput(node)) {
for (const Edge* incoming_edge : node.in_edges()) {
if (incoming_edge->IsControlEdge()) {
continue;
}
Node* incoming_node = incoming_edge->src();
if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
<< incoming_node->name() << " " << incoming_node->type_string();
return true;
}
}
}
return false;
}
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
for (int i = 0; i < graph->num_node_ids(); ++i) {
// We rely on the node IDs in the cycle detection graph being consecutive
// integers starting from 0.
CHECK_EQ(i, cycles->NewNode());
}
// Compute the loop structure of the graph.
std::vector<ControlFlowInfo> control_flow_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
// The clustering code must avoid adding cycles to the graph to prevent
// deadlock. However, the graph may contain loops, which would trigger the
// cycle detection code. To handle loops, we alter the structure of the cycle
// detection graph, disconnecting each loop from the enclosing graph.
// Specifically, we:
// * add a new "frame" node for each loop.
// * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
// to/from the corresponding frame node. In essence, we collapse the loop
// into a single node for the purpose of cycle detection in the enclosing
// graph.
// * the body of the loop should now be disconnected from the rest of the
// graph; we make it acyclic by breaking loop backedges (edges outgoing from
// "NextIteration" nodes.
// Map from frame name strings to node IDs in the cycle detection graph.
std::unordered_map<string, int> frame_nodes;
// Get the cycle graph node ID for frame 'frame_name', or add one if none
// exists.
auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
if (frame_id < 0) {
// The emplace succeeded; we have not allocated a frame node yet.
frame_id = cycles->NewNode();
}
return frame_id;
};
for (Edge const* edge : graph->edges()) {
if (edge->dst()->IsEnter()) {
// Lift edges to an "Enter" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->dst()->id()].frame_name;
int dst = GetOrAddFrameNodeId(frame_name);
if (!cycles->InsertEdge(edge->src()->id(), dst)) {
return errors::Internal(
"Cycle detected when adding enter->frame edge: ",
DescribeCycle(cycles, *graph, edge->src()->id(), dst));
}
continue;
}
if (edge->src()->IsExit()) {
// Lift edges from an "Exit" node to the corresponding frame node.
const string& frame_name =
control_flow_info[edge->src()->id()].frame_name;
int src = GetOrAddFrameNodeId(frame_name);
if (!cycles->InsertEdge(src, edge->dst()->id())) {
return errors::Internal(
"Cycle detected when adding frame->exit edge: ",
DescribeCycle(cycles, *graph, src, edge->dst()->id()));
}
// Drop the original edge.
continue;
}
if (edge->src()->IsNextIteration()) {
// Break loop back-edges.
continue;
}
if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
// This should never happen. All cycles in the graph should contain
// a control flow operator.
return errors::Internal(
"Found cycle in graph without control flow operator during XLA "
"compilation: ",
DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
}
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Contains utilities for clustering compilable graph nodes via XLA.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
namespace tensorflow {
// The attribute that marks nodes to be grouped into functions by the
// encapsulate subgraphs pass.
extern const char* const kXlaClusterAttr;
// The attribute that marks nodes in a cluster to be placed outside the xla
// compilation by the encapsulate subgraphs pass.
extern const char* const kXlaOutsideCompilationAttr;
using OrderedNodeSet = std::set<Node*, NodeComparatorID>;
// Returns the DeviceType corresponding to 'device'.
Status DeviceToDeviceType(const string& device, DeviceType* device_type);
// Returns true if `node` has a ref tensor input that it forwards to its output.
bool HasForwardedRefInput(const Node& node);
// Creates a graph representation to enable cycle detection when clustering.
// This representation handles loops in graph by disconnecting each loop from
// the enclosing graph.
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
@ -87,6 +88,46 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
REGISTER_KERNEL_BUILDER(Name("Shape") \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<int32>("out_type") \
.TypeConstraint("T", TYPES), \
ShapeOp<int32>); \
REGISTER_KERNEL_BUILDER(Name("Shape") \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<int64>("out_type") \
.TypeConstraint("T", TYPES), \
ShapeOp<int64>); \
REGISTER_KERNEL_BUILDER(Name("ShapeN") \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<int32>("out_type") \
.TypeConstraint("T", TYPES), \
ShapeNOp<int32>); \
REGISTER_KERNEL_BUILDER(Name("ShapeN") \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<int64>("out_type") \
.TypeConstraint("T", TYPES), \
ShapeNOp<int64>); \
REGISTER_KERNEL_BUILDER(Name("Size") \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<int32>("out_type") \
.TypeConstraint("T", TYPES), \
SizeOp<int32>); \
REGISTER_KERNEL_BUILDER(Name("Size") \
.Device(DEVICE) \
.HostMemory("output") \
.TypeConstraint<int64>("out_type") \
.TypeConstraint("T", TYPES), \
SizeOp<int64>); \
REGISTER_KERNEL_BUILDER( \
Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \
TYPES), \
RankOp); \
REGISTER_KERNEL_BUILDER( \
Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
XlaAssignVariableOp); \
@ -95,7 +136,16 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
SwitchOp); \
REGISTER_KERNEL_BUILDER( \
Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp);
Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
NextIterationOp); \
REGISTER_KERNEL_BUILDER(Name("LoopCond") \
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
LoopCondOp);
} // namespace tensorflow

View File

@ -0,0 +1,328 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include <atomic>
#include <deque>
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
namespace tensorflow {
// Is 'node' an operator that consumes only the shape of its input, not the
// data itself?
static bool IsShapeConsumerOp(const Node& node) {
return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
node.type_string() == "Rank" || node.type_string() == "Size";
}
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
bool IsXlaFusable(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
{// tf2xla/kernels/aggregate_ops.cc
"AddN",
// tf2xla/kernels/binary_ops.cc
"Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
"FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
"LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
"ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
"TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
"GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
"SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
// tf2xla/kernels/unary_ops.cc
"ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
"Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
"Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
"Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
"Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
"Square", "Tan", "Tanh", "Real", "Imag",
// tf2xla/kernels/bcast_ops.cc
"BroadcastArgs", "BroadcastGradientArgs",
// tf2xla/kernels/bias_ops.cc
"BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
// tf2xla/kernels/cast_op.cc
"Cast",
// tf2xla/kernels/concat_op.cc
"Concat", "ConcatV2", "ConcatOffset",
// tf2xla/kernels/const_op.cc
"Const",
// tf2xla/kernels/elu_op.cc
"Elu", "EluGrad", "Selu", "SeluGrad",
// tf2xla/kernels/fill_op.cc
"Fill",
// tf2xla/kernels/identity_op.cc
"Identity", "IdentityN", "PreventGradient",
"StopGradient", /*"Snapshot",*/
// tf2xla/kernels/index_ops.cc
"ArgMax", "ArgMin",
// tf2xla/kernels/mirror_pad_op.cc
"MirrorPad",
// tf2xla/kernels/one_hot_op.cc
"OneHot",
// tf2xla/kernels/pack_op.cc
"Pack",
// tf2xla/kernels/pad_op.cc
"Pad", "PadV2",
// tf2xla/kernels/relu_op.cc
"Relu", "Relu6", "ReluGrad", "Relu6Grad",
// tf2xla/kernels/reshape_op.cc
"Reshape",
// tf2xla/kernels/reverse_op.cc
"Reverse", "ReverseV2",
// tf2xla/kernels/reverse_sequence_op.cc
"ReverseSequence",
// tf2xla/kernels/shape_op.cc
"Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
"ZerosLike", "OnesLike",
// tf2xla/kernels/slice_op.cc
"Slice",
// tf2xla/kernels/split_op.cc
"Split", "SplitV",
// tf2xla/kernels/strided_slice_op.cc
"StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
// tf2xla/kernels/tile_ops.cc
"Tile",
// tf2xla/kernels/transpose_op.cc
"Transpose", "InvertPermutation",
// tf2xla/kernels/unpack_op.cc
"Unpack"});
return elementwise_ops->count(node.op()) > 0;
}
Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) {
VLOG(2) << "Here at fusion optimizer";
// TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
// Once that happens, the expected interaction between this optimizer and when
// the global_jit_level is set is as follows: Fusion optimizer will replace
// appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
// be further compiled where possible via mark_for_compilation_pass. Note that
// this might lead to inefficient clustering, and it is best to use either the
// fusion optimizer or the global_jit flag, and not combine the two.
// Create a Graph out of GraphDef. This is required currently because the
// helpers around clustering, encapsulation etc work on graphs.
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item.graph.library());
Graph graph(function_library);
ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
shape_refiner.set_require_shape_inference_fns(false);
shape_refiner.set_disable_constant_propagation(true);
ImportGraphDefOptions options;
// Graph optimization happens at the late stage of graph execution, when
// colocation constraints are already validated previously and the device
// placement of nodes has also completed, so there is no need to validate
// colocation constraints again.
options.validate_colocation_constraints = false;
options.validate_shape = false;
TF_RETURN_IF_ERROR(
ImportGraphDef(options, item.graph, &graph, &shape_refiner));
// Collect nodes that can be fused via XLA, while ignoring those that
// explicitly ask for XLA: (*) nodes that are marked to be compiled
// explicitly. (*) nodes assigned to XLA device.
OrderedNodeSet compilation_candidates;
for (Node* node : graph.op_nodes()) {
// If there is a _XlaCompile annotation, ignore the node if it is
// true. Nodes are marked with this attr via experimental_jit_scope, and
// will be handled by the mark_for_compilation pass.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
if (status.ok() && compile) {
continue;
}
// If there is already a _XlaCluster annotation, ignore the node. Nodes are
// marked with this attr to indicate they are already part of a cluster and
// hence ignored.
status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
if (status.ok()) {
continue;
}
// If there is an explicit XLA device placement, ignore the node.
DeviceType device_type("");
TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
if (device_type.type_string().find("XLA") != string::npos) continue;
// Assume all fusable ops are registered.
// TODO(hpucha): Check for registration if possible.
if (!IsXlaFusable(node->def())) {
continue;
}
// XLA does not offer guaranteed aliasing between the input and output of
// the XLA cluster so it can't implement the forward-tensor-ref semantic.
// Leave such nodes out of XLA clusters.
if (HasForwardedRefInput(*node)) {
continue;
}
compilation_candidates.insert(node);
}
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
*output = item.graph;
return Status::OK();
}
GraphCycles cycles;
TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
// TODO(hpucha): Make clustering more robust. There are two known issues that
// we need to mitigate: (a) Non-resource variables can cause deadlocks
// when clustering changes order of execution. See b/77263461 for a specific
// example. (b) Queue operations can also cause deadlocks. See b/77261498 for
// example.
struct Cluster {
// Identifies the node that represents this cluster in the cycle detection
// graph.
int representative = -1;
};
// Each compilation candidate belongs to a cluster. The cluster's
// representative names the node in the 'cycles' graph that represents the
// cluster.
std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
std::deque<UnionFind<Cluster>*> worklist;
for (Node* node : compilation_candidates) {
Cluster& cluster = clusters[node->id()].Get();
cluster.representative = node->id();
worklist.push_back(&clusters[node->id()]);
}
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle. This is a simplified
// version of the clustering in mark_for_compilation_pass that also deals with
// nodes that are explicitly tagged to be compiled/clustered.
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
Node* node_from = graph.FindNodeId(from);
if (node_from->IsControlFlow()) {
// Control flow nodes aren't compilation candidates and should never
// appear.
return errors::Internal(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
for (int to : cycles.Successors(from)) {
if (to >= graph.num_node_ids()) {
// Node is a "frame" node that is present only in the cycle detection
// graph. No clustering is possible.
continue;
}
Node* node_to = graph.FindNodeId(to);
if (compilation_candidates.find(node_to) ==
compilation_candidates.cend()) {
continue;
}
// Do not cluster across devices.
if (node_from->def().device() != node_to->def().device()) {
VLOG(2) << "Devices " << node_from->def().device() << " "
<< node_to->def().device();
VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
<< node_to->assigned_device_name();
continue;
}
// Ops that consume shapes cannot be the root of a cluster. This is an
// optimization.
if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
continue;
}
// If contracting the edge would create a cycle, bail out.
// However, just because we can't merge the clusters now does not mean
// we won't be able to merge them in the future.
// e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
// 1->3. But if we first contract 1->2 then we can later contract 1->3.
if (!cycles.ContractEdge(from, to)) continue;
// Merge the clusters. ContractEdge uses 'from' as the number of the
// merged node, so make sure 'from' is the chosen representative.
clusters[from].Merge(&clusters[to]);
worklist.push_back(&clusters[from]);
break;
}
}
// Count the number of non-trivial elements in each cluster.
std::vector<int> effective_cluster_sizes(graph.num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Identity nodes will be removed if the node gets marked for compilation.
// Therefore we don't want to count them towards the effective cluster size.
if (n->def().op() != "Identity") {
effective_cluster_sizes[cluster]++;
}
}
const int min_cluster_size = 2;
int num_clusters = 0;
for (auto size : effective_cluster_sizes) {
if (size >= min_cluster_size) {
VLOG(3) << "Cluster " << num_clusters << " " << size;
num_clusters++;
}
}
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
// Sequence number generator to ensure clusters have unique names.
static std::atomic<int64> cluster_sequence_num;
for (Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
// Compile if this is a cluster of >= min_cluster_size compilable operators.
if (effective_cluster_sizes[cluster] >= min_cluster_size) {
string& name = cluster_names[cluster];
if (name.empty()) {
name = strings::StrCat("cluster_", cluster_sequence_num++);
}
n->AddAttr(kXlaClusterAttr, name);
VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
}
}
graph.ToGraphDef(output);
return Status::OK();
}
REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
} // namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
// Optimizes graphs by fusing ops where possible, resulting in more efficient
// execution.
class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
public:
XlaFusionOptimizer() {}
~XlaFusionOptimizer() override {}
Status Init(
const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
return Status::OK();
}
string name() const override { return "xla-fusion"; };
Status Optimize(grappler::Cluster* cluster,
const grappler::GrapplerItem& item,
GraphDef* output) override;
void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
const GraphDef& optimize_output, double result) override {
// Nothing to do for XlaFusionOptimizer.
}
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_

View File

@ -0,0 +1,183 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
class XlaFusionOptimizerTest : public grappler::GrapplerTest {
protected:
std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
std::unordered_map<string, string> ids;
for (const NodeDef& node : graph.node()) {
string cluster;
if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
CHECK(!cluster.empty());
ids[node.name()] = cluster;
}
}
return ids;
}
};
TEST_F(XlaFusionOptimizerTest, Chains) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a =
ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
Node* d =
ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
EXPECT_EQ(clusters["E"], clusters["F"]);
EXPECT_NE(clusters["B"], clusters["E"]);
EXPECT_TRUE(clusters.find("A") == clusters.cend());
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, FusableOps) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(2, clusters.size());
EXPECT_EQ(clusters["C"], clusters["E"]);
EXPECT_TRUE(clusters.find("D") == clusters.cend());
}
TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp(
"Placeholder",
builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* b = ops::SourceOp(
"Placeholder",
builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
Node* c = ops::BinaryOp(
"Add", a, b,
builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
ops::UnaryOp("Cos", e,
builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b =
ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_TRUE(clusters.empty());
}
TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
GraphDef graph;
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
TF_ASSERT_OK(builder.ToGraphDef(&graph));
}
grappler::GrapplerItem item;
item.graph = graph;
XlaFusionOptimizer optimizer;
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
auto clusters = GetClusters(output);
EXPECT_EQ(3, clusters.size());
EXPECT_EQ(clusters["A"], clusters["B"]);
EXPECT_EQ(clusters["A"], clusters["C"]);
}
} // namespace
} // namespace tensorflow

View File

@ -120,6 +120,19 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "bucketize_op_test",
size = "small",
srcs = ["bucketize_op_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "categorical_op_test",
size = "small",
@ -532,7 +545,9 @@ tf_xla_py_test(
],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],

View File

@ -0,0 +1,78 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bucketize_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class BucketizationOpTest(XLATestCase):
def testInt(self):
with self.test_session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4]
self.assertAllEqual(expected_out,
sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]}))
def testFloat(self):
with self.test_session() as sess:
p = array_ops.placeholder(dtypes.float32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.])
expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4]
self.assertAllEqual(
expected_out,
sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]}))
def test2DInput(self):
with self.test_session() as sess:
p = array_ops.placeholder(dtypes.float32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11])
expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]]
self.assertAllEqual(
expected_out, sess.run(op,
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
def testInvalidBoundariesOrder(self):
with self.test_session() as sess:
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11])
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Expected sorted boundaries"):
sess.run(op, {p: [-5, 0]})
def testBoundariesNotList(self):
with self.test_session():
with self.assertRaisesRegexp(TypeError, "Expected list.*"):
p = array_ops.placeholder(dtypes.int32)
with self.test_scope():
math_ops._bucketize(p, boundaries=0)
if __name__ == "__main__":
test.main()

View File

@ -160,6 +160,77 @@ class EagerTest(XLATestCase):
for _ in range(100):
values.append(var.value())
# The shape, shape_n, size, and rank are tested here because their
# execution kernels (as opposed to compilation only tf2xla kernels)
# are distincts from tf2xla kernels.
def testShape(self):
def const(value):
return array_ops.shape(
constant_op.constant(value)).numpy()
def ones(value):
return array_ops.shape(
array_ops.ones(value)).numpy()
with self.test_scope():
# Shapes of directly constructed tensors
self.assertAllEqual([], const(3))
self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))
# Shapes of tensors created by op running on device
# We make this distinction because directly constructed tensors
# are treated differently in a few places that can influence shape:
# - they always have on_host_tensor
# - they and their shapes can be cached
# - they end up on device via a copy, instead of as program output
self.assertAllEqual([], ones([]))
self.assertAllEqual([3], ones([3]))
self.assertAllEqual([2, 2], ones([2, 2]))
self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))
def testShapeN(self):
with self.test_scope():
# Shapes of directly constructed tensors
shapes = array_ops.shape_n([
constant_op.constant(1.0),
constant_op.constant([1.0, 2.0, 3.0]),
constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
self.assertAllEqual(
[[], [3], [2, 2]],
[x.numpy().tolist() for x in shapes])
# Shapes of tensors created by op running on device
shapes = array_ops.shape_n([
array_ops.ones([]),
array_ops.ones([3]),
array_ops.ones([2, 2])])
self.assertAllEqual(
[[], [3], [2, 2]],
[x.numpy().tolist() for x in shapes])
def testSize(self):
with self.test_scope():
self.assertEqual(
1, array_ops.size(constant_op.constant(1.0)).numpy())
self.assertEqual(
3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
self.assertEqual(
4, array_ops.size(
constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())
def testRank(self):
with self.test_scope():
self.assertEqual(
0, array_ops.rank(constant_op.constant(1.0)).numpy())
self.assertEqual(
1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
self.assertEqual(
2, array_ops.rank(
constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())
class EagerFunctionTest(XLATestCase):

View File

@ -65,9 +65,7 @@ class RGBToHSVTest(XLATestCase):
join1 = array_ops.stack(split1)
join2 = array_ops.stack(split2)
batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2],
{
batch0: inp
})
{batch0: inp})
# Verify that processing batch elements together is the same as separate
self.assertAllClose(batch1, join1)
@ -401,9 +399,7 @@ class AdjustSaturationTest(XLATestCase):
x = array_ops.placeholder(dtypes.float32, shape=x_shape)
with self.test_scope():
y_fused = self._adjust_saturation(x,
scale).eval(feed_dict={
x: x_np
})
scale).eval(feed_dict={x: x_np})
self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5)
@ -412,7 +408,8 @@ class ResizeBilinearTest(XLATestCase):
def _assertForwardOpMatchesExpected(self,
image_np,
target_shape,
expected=None):
expected=None,
large_tolerance=False):
if expected is None:
self.fail("expected must be specified")
with self.test_session() as sess, self.test_scope():
@ -420,7 +417,11 @@ class ResizeBilinearTest(XLATestCase):
resized = gen_image_ops.resize_bilinear(
image, target_shape, align_corners=True)
out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]})
self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out)
if large_tolerance:
self.assertAllClose(
expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1)
else:
self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out)
def _assertBackwardOpMatchesExpected(self,
grads_np,
@ -555,6 +556,28 @@ class ResizeBilinearTest(XLATestCase):
[[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]],
dtype=np.float32))
def testAlignCorners4x4To8x8(self):
self._assertForwardOpMatchesExpected(
(np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array(
[[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8],
expected=3 *
(np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array(
[[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)),
large_tolerance=True)
def testAlignCorners8x8To16x16(self):
self._assertForwardOpMatchesExpected(
(np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array(
[[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0,
[16, 16],
expected=7 * (np.array(
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],
dtype=np.float32) + np.array(
[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11],
[12], [13], [14], [15]],
dtype=np.float32)),
large_tolerance=True)
if __name__ == "__main__":
test.main()

View File

@ -22,6 +22,8 @@ import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import googletest
@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase):
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
self.assertTrue((not np.array_equal(y, z)) or
(not np.array_equal(z, w)) or
(not np.array_equal(y, w)))
(not np.array_equal(z, w)) or (not np.array_equal(y, w)))
def testRandomUniformIsNotConstant(self):
def rng(dtype):
return random_ops.random_uniform(shape=[2], dtype=dtype,
maxval=1000000)
return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000)
for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype)
def testRandomNormalIsNotConstant(self):
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
@ -70,12 +72,20 @@ class RandomOpsTest(XLATestCase):
for dtype in self._random_types():
with self.test_session() as sess:
with self.test_scope():
x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
maxval=33)
x = random_ops.random_uniform(
shape=[1000], dtype=dtype, minval=-2, maxval=33)
y = sess.run(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
def testTruncatedNormalIsNotConstant(self):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
# TODO(b/34339814): implement inverse erf support for non-F32 types.
self._testRngIsNotConstant(rng, dtypes.float32)
def testTruncatedNormalIsInRange(self):
count = 10000
# TODO(b/34339814): implement inverse erf support for non-F32 types.
@ -87,6 +97,29 @@ class RandomOpsTest(XLATestCase):
self.assertTrue((y >= -2).sum() == count)
self.assertTrue((y <= 2).sum() == count)
def testShuffle1d(self):
with self.test_session() as sess:
with self.test_scope():
x = math_ops.range(20)
shuffle = random_ops.random_shuffle(x)
result = sess.run(shuffle)
expected = range(20)
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
self.assertAllEqual(set(result), set(expected))
def testShuffle2d(self):
with self.test_session() as sess:
with self.test_scope():
x = array_ops.diag(math_ops.range(20))
shuffle = random_ops.random_shuffle(x)
result = sess.run(shuffle)
expected = np.diag(range(20)).flatten()
# Compare sets to avoid randomness behavior changes but make sure still
# have all the values.
self.assertAllEqual(len(result.flatten()), len(expected))
self.assertAllEqual(set(result.flatten()), set(expected))
if __name__ == '__main__':
googletest.main()

View File

@ -1438,7 +1438,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
// connected to all source nodes in the graph. Many graphs violate this
// invariant.
std::vector<ControlFlowInfo> cf_info;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));
std::vector<string> unreachable_nodes;
TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes));
if (!unreachable_nodes.empty()) {
return errors::InvalidArgument(
"The following nodes are unreachable from the source in the graph: ",
tensorflow::str_util::Join(unreachable_nodes, ", "));
}
// Builds Frames, indexed by name.
std::unordered_map<string, Frame> frames;

View File

@ -18,6 +18,7 @@ tf_kernel_library(
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
"bucketize_op.cc",
"cast_op.cc",
"categorical_op.cc",
"cholesky_op.cc",

View File

@ -0,0 +1,67 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace {
class BucketizeOp : public XlaOpKernel {
public:
explicit BucketizeOp(OpKernelConstruction* context) : XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_));
OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()),
errors::InvalidArgument("Expected sorted boundaries"));
}
void Compile(XlaOpKernelContext* context) override {
xla::XlaBuilder* builder = context->builder();
const DataType dtype = context->input_type(0);
xla::XlaOp input = context->Input(0);
xla::XlaOp boundaries = builder->ConstantR1<float>(boundaries_);
// TODO(phawkins): the following behavior matches the behavior of the core
// Bucketize kernel. However, comparing an int32 or int64 against float may
// lead to inaccurate bucketing due to rounding.
if (dtype == DT_DOUBLE) {
input = builder->ConvertElementType(input, xla::F64);
boundaries = builder->ConvertElementType(boundaries, xla::F64);
} else {
input = builder->ConvertElementType(input, xla::F32);
}
xla::XlaOp comparison = builder->ConvertElementType(
builder->Ge(builder->Broadcast(input, {1}), boundaries,
/*broadcast_dimensions=*/{0}),
xla::S32);
xla::XlaOp buckets = builder->Reduce(
comparison, /*init_value=*/builder->ConstantR0<int32>(0),
/*computation=*/xla::CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
context->SetOutput(0, buckets);
}
private:
std::vector<float> boundaries_;
};
REGISTER_XLA_OP(Name("Bucketize"), BucketizeOp);
} // namespace
} // namespace tensorflow

View File

@ -48,11 +48,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Building If: " << input_types_.size() << " inputs";
std::vector<xla::XlaOp> inputs(input_types_.size());
std::vector<XlaCompiler::Argument> arguments(input_types_.size());
for (int i = 0; i < input_types_.size(); ++i) {
XlaCompiler::Argument& arg = arguments[i];
DataType type = ctx->input_type(i + 1);
if (type == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
@ -60,7 +60,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
arg.initialized = resource->initialized();
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = resource->kind();
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
arg.type = resource->type();
arg.shape = resource->shape();
@ -79,7 +78,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = input_types_[i];
arg.shape = ctx->InputShape(i + 1);
inputs[i] = ctx->Input(i + 1);
VLOG(2) << "Arg type: " << DataTypeString(arg.type)
<< " shape: " << arg.shape.DebugString();
}
@ -100,6 +98,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
arguments, &else_result));
bool has_tensor_array_gradients = false;
for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
XlaResource* resource;
@ -121,9 +120,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
if (!resource->tensor_array_gradients().empty())
has_tensor_array_gradients = true;
}
}
// Recompile the functions to update the argument shapes for tensor arrays.
if (has_tensor_array_gradients) {
then_result = {};
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_,
arguments, &then_result));
else_result = {};
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
arguments, &else_result));
}
// Check that both branches have identical input shapes.
OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1,
errors::FailedPrecondition("Expected one input shape"));
@ -175,6 +186,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
"Mismatch in resource of then and else branch for resource ", i));
}
int num_inputs = then_result.input_mapping.size();
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = then_result.input_mapping[i] + 1;
if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
} else {
inputs[i] = ctx->Input(i + 1);
}
}
xla::XlaOp outputs =
b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
b->Tuple(inputs), *else_result.computation);

View File

@ -99,27 +99,34 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters(
return dims;
}
// Form a 2D convolution kernel like:
// 1 2 3 2 1
// 2 4 6 4 2
// 1/9 * 3 6 9 6 3
// 2 4 6 4 2
// 1 2 3 2 1
// by multiplying two 1D kernels of the form:
// 1/3 * [1 2 3 2 1]
// If the 2D kernel would be very large, the 1D kernel can be applied once in
// each dimension due to the symmetry of the kernel along all axis to reduce the
// computational intensity.
std::vector<float> Make1DKernel(int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
return kernel;
}
// Kernels with more than 16 spatial elements are considered intense and the
// kernel should applied to each dimension independently.
const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
gtl::ArraySlice<int64> kernel_size,
int64 channels) {
// Form a 2D convolution kernel like:
// 1 2 3 2 1
// 2 4 6 4 2
// 1/9 * 3 6 9 6 3
// 2 4 6 4 2
// 1 2 3 2 1
// by multiplying two 1D kernels of the form:
// 1/3 * [1 2 3 2 1]
auto make_1d_kernel = [](int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
return kernel;
};
xla::XlaOp channels_iota;
// DT_INT32 Iota will always return status::OK().
TF_CHECK_OK(
@ -133,12 +140,37 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
xla::PrimitiveType::F32);
return builder->Mul(
builder->Mul(diag,
builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])),
builder->ConstantR1<float>(Make1DKernel(kernel_size[1])),
/*broadcast_dimensions=*/{1}),
builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])),
builder->ConstantR1<float>(Make1DKernel(kernel_size[0])),
/*broadcast_dimensions=*/{0});
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
gtl::ArraySlice<int64> kernel_size,
int64 channels, int64 dim) {
xla::XlaOp channels_iota;
// DT_INT32 Iota will always return status::OK().
TF_CHECK_OK(
XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
auto diag = builder->ConvertElementType(
builder->Eq(builder->Broadcast(
channels_iota,
{dim == 0 ? (2 * kernel_size[0] - 1) : 1,
dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
channels_iota, /*broadcast_dimensions=*/{2}),
xla::PrimitiveType::F32);
if (dim == 1) {
return builder->Mul(
diag, builder->ConstantR1<float>(Make1DKernel(kernel_size[1])),
/*broadcast_dimensions=*/{1});
}
return builder->Mul(diag,
builder->ConstantR1<float>(Make1DKernel(kernel_size[0])),
/*broadcast_dimensions=*/{0});
}
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
const xla::XlaOp& input,
const int num_spatial_dims,
@ -165,20 +197,42 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
dimension_numbers.add_output_spatial_dimensions(1 + i);
dimension_numbers.add_kernel_spatial_dimensions(i);
}
dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
ResizeConvolutionDims dims =
ComputeResizeConvolutionParameters(in_size, out_size);
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
xla::XlaOp output = builder->ConvGeneralDilated(
input, kernel, dims.stride,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.kernel_size,
/*rhs_dilation=*/{1, 1}, dimension_numbers);
xla::XlaOp output;
// Split convolutions into independent dimensions if they wmuld be a very
// large kernel.
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
output = builder->ConvGeneralDilated(
input, kernel, dims.stride,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.kernel_size,
/*rhs_dilation=*/{1, 1}, dimension_numbers);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
output = builder->ConvGeneralDilated(
input, kernel0, {dims.stride[0], 1},
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
output = builder->ConvGeneralDilated(
output, kernel1, {1, dims.stride[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
}
// Add broadcasts to handle expanding from a size == 1 dimension to a
// size > 1 dimension.
@ -214,26 +268,63 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
}
dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
xla::XlaOp output;
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
// Broadcast the input kernel where the forward op expanded from a size == 1
// dimension to a size > 1 dimension. This has the effect of summing the
// gradient contributions in that dimension.
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] == 1 && grad_size[i] > 1) {
kernel = builder->Add(kernel, builder->ConstantR1<float>(grad_size[i], 0),
/*broadcast_dimensions=*/{i});
// Broadcast the input kernel where the forward op expanded from a size == 1
// dimension to a size > 1 dimension. This has the effect of summing the
// gradient contributions in that dimension.
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] == 1 && grad_size[i] > 1) {
kernel =
builder->Add(kernel, builder->ConstantR1<float>(grad_size[i], 0),
/*broadcast_dimensions=*/{i});
}
}
}
xla::XlaOp output = builder->ConvGeneralDilated(
grad, kernel, /*window_strides=*/dims.kernel_size,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.stride,
/*rhs_dilation=*/{1, 1}, dimension_numbers);
output = builder->ConvGeneralDilated(
grad, kernel, /*window_strides=*/dims.kernel_size,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.stride,
/*rhs_dilation=*/{1, 1}, dimension_numbers);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
// Broadcast the input kernel where the forward op expanded from a size == 1
// dimension to a size > 1 dimension. This has the effect of summing the
// gradient contributions in that dimension.
if (in_size[0] == 1 && grad_size[0] > 1) {
kernel0 =
builder->Add(kernel0, builder->ConstantR1<float>(grad_size[0], 0),
/*broadcast_dimensions=*/{0});
}
if (in_size[1] == 1 && grad_size[1] > 1) {
kernel1 =
builder->Add(kernel0, builder->ConstantR1<float>(grad_size[1], 0),
/*broadcast_dimensions=*/{1});
}
output = builder->ConvGeneralDilated(
grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.stride[0], 1},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
output = builder->ConvGeneralDilated(
output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/{1, dims.stride[1]},
/*rhs_dilation=*/{1, 1}, dimension_numbers);
}
// If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
// Opposite of the slice performed by the forward op.

View File

@ -17,6 +17,9 @@ limitations under the License.
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@ -55,6 +58,78 @@ class RandomUniformOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"),
RandomUniformOp);
class RandomShuffleOp : public XlaOpKernel {
public:
explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
auto builder = ctx->builder();
xla::XlaOp input = ctx->Input(0);
TensorShape input_shape = ctx->InputShape(0);
const int64 n = input_shape.dim_size(0);
int64 num_elements = 1;
for (tensorflow::TensorShapeDim dimension : input_shape) {
num_elements *= dimension.size;
}
if (num_elements <= 1 || n <= 1) {
// No shuffling is required, so copy input directly to output
ctx->SetOutput(0, input);
} else {
// Generate the random swaps for the indices.
auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
auto swaps =
builder->RngUniform(builder->ConstantR0<int32>(0),
builder->ConstantR0<int32>(n), swaps_shape);
// Generate range(n) as the initial value for the indices to be swapped.
xla::XlaOp indices;
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices));
// Swap the indices at i and swaps[i].
auto swap_body_fn = [&](xla::XlaOp i,
gtl::ArraySlice<xla::XlaOp> loop_vars,
xla::XlaBuilder* builder)
-> xla::StatusOr<std::vector<xla::XlaOp>> {
auto swaps = loop_vars[0];
auto indices = loop_vars[1];
i = builder->Reshape(i, {1});
// temp = indices[i]
auto temp = builder->DynamicSlice(indices, i, {1});
// swap_index = swaps[i]
auto swap_index = builder->DynamicSlice(swaps, i, {1});
// swap_value = indices[swaps[i]]
auto swap_value = builder->DynamicSlice(indices, swap_index, {1});
// indices[i] = indices[swaps[i]]
indices = builder->DynamicUpdateSlice(indices, swap_value, i);
// indices[swaps[i]] = temp
indices = builder->DynamicUpdateSlice(indices, temp, swap_index);
return std::vector<xla::XlaOp>{swaps, indices};
};
// for i in range(n):
auto swap_loop_result =
XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
"indices_swap_loop", builder)
.ValueOrDie();
auto swapped_indices = swap_loop_result[1];
// Gather the data using the swapped indices as the shuffled order.
auto indices_tensor_shape = TensorShape({n});
DataType type = ctx->expected_output_dtype(0);
xla::XlaOp gather;
OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
indices_tensor_shape,
/*axis=*/0, /*indices_are_nd=*/false, type,
DT_INT32, builder, &gather));
ctx->SetOutput(0, gather);
}
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp);
};
REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp);
class RandomUniformIntOp : public XlaOpKernel {
public:
explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
@ -127,13 +202,8 @@ class TruncatedNormalOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::Shape xla_element_shape =
xla::ShapeUtil::MakeShape(xla_shape.element_type(), {});
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
xla::XlaOp stddev = XlaHelpers::One(b, dtype);
xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape);
auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) {
return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0);
@ -151,34 +221,38 @@ class TruncatedNormalOp : public XlaOpKernel {
// out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd
// candidate = select(out_of_range_mask, rng_normal(), candidate)
// }
std::unique_ptr<xla::XlaBuilder> test_builder =
b->CreateSubBuilder("truncated_normal_test");
{
auto* b = test_builder.get();
xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
out_of_range_mask(candidate, b);
OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status());
}
std::unique_ptr<xla::XlaBuilder> body_builder =
b->CreateSubBuilder("truncated_normal_body");
{
auto* b = body_builder.get();
xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
xla::XlaOp to_resample = out_of_range_mask(candidate, b);
std::vector<xla::XlaOp> initial_values = {
// The current candidate.
b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()),
// The to_resample mask, where 'true' identifies a location in the
// current candidate that is out of range and must be regenerated.
b->Broadcast(b->ConstantR0<bool>(true), shape.dim_sizes()),
// Is any element in the mask true?
b->ConstantR0<bool>(true)};
auto condition = [&](gtl::ArraySlice<xla::XlaOp> values,
xla::XlaBuilder* b) -> xla::StatusOr<xla::XlaOp> {
// Continue while any element in the mask is true.
return values[2];
};
auto body =
[&](gtl::ArraySlice<xla::XlaOp> values,
xla::XlaBuilder* b) -> xla::StatusOr<std::vector<xla::XlaOp>> {
xla::XlaOp candidate = values[0];
xla::XlaOp to_resample = values[1];
xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
xla::XlaOp stddev = XlaHelpers::One(b, dtype);
b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate);
}
xla::StatusOr<xla::XlaComputation> test_computation = test_builder->Build();
OP_REQUIRES_OK(ctx, test_computation.status());
xla::StatusOr<xla::XlaComputation> body_computation = body_builder->Build();
OP_REQUIRES_OK(ctx, body_computation.status());
xla::XlaOp result = b->While(test_computation.ValueOrDie(),
body_computation.ValueOrDie(), candidate);
ctx->SetOutput(0, result);
candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape),
candidate);
// Compute a new to_resample mask, and determine whether any value is
// still out of range.
to_resample = out_of_range_mask(candidate, b);
TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b));
return std::vector<xla::XlaOp>{candidate, to_resample, done};
};
auto result =
XlaWhileLoop(condition, body, initial_values, "truncated_normal", b);
OP_REQUIRES_OK(ctx, result.status());
ctx->SetOutput(0, result.ValueOrDie()[0]);
}
};

View File

@ -43,7 +43,7 @@ class ShapeOp : public XlaOpKernel {
DataType out_dtype_;
};
REGISTER_XLA_OP(Name("Shape"), ShapeOp);
REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
@ -65,7 +65,7 @@ class ShapeNOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp);
REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
class RankOp : public XlaOpKernel {
public:
@ -81,7 +81,7 @@ class RankOp : public XlaOpKernel {
}
};
REGISTER_XLA_OP(Name("Rank"), RankOp);
REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
class SizeOp : public XlaOpKernel {
public:
@ -100,7 +100,7 @@ class SizeOp : public XlaOpKernel {
}
};
REGISTER_XLA_OP(Name("Size"), SizeOp);
REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
public:
@ -189,10 +189,9 @@ class SqueezeOp : public XlaOpKernel {
if (!wrapped_squeeze_dims.empty()) {
if (wrapped_squeeze_dims.count(i) > 0) {
OP_REQUIRES(ctx, existing_dim == 1,
errors::InvalidArgument("Tried to explicitly squeeze "
"dimension ",
i, " but dimension was not 1: ",
existing_dim));
errors::InvalidArgument(
"Tried to explicitly squeeze dimension ", i,
" but dimension was not 1: ", existing_dim));
} else {
// This dimension is not being squeezed.
new_shape.push_back(existing_dim);

View File

@ -110,7 +110,6 @@ xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
FloatLiteral(body_builder, a_shape.element_type(), 0.5));
// a[..., i+1:, i]
auto ip1 = body_builder->Add(i, body_builder->ConstantR0<int32>(1));
// select the whole i-th column, then mask out all rows above i+1
TF_ASSIGN_OR_RETURN(
auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1}));

View File

@ -40,6 +40,37 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
return Status::OK();
}
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal) {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
*literal = xla::BorrowingLiteral(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
return Status::OK();
}
Status HostTensorsToBorrowingLiteralTuple(
tensorflow::gtl::ArraySlice<Tensor> host_tensors,
xla::BorrowingLiteral* literal) {
std::vector<const char*> buf_ptrs;
buf_ptrs.reserve(host_tensors.size());
std::vector<xla::Shape> tensor_shapes(host_tensors.size());
for (int i = 0; i < host_tensors.size(); i++) {
// Validate runtime shapes and fail if it doesn't match the contract.
const Tensor* tensor = &host_tensors[i];
buf_ptrs.emplace_back(static_cast<const char*>(DMAHelper::base(tensor)));
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(),
&tensor_shapes[i]));
}
*literal = xla::BorrowingLiteral(
buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes));
return Status::OK();
}
Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal,
Tensor* host_tensor) {
TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@ -29,6 +30,17 @@ namespace tensorflow {
// unsupported type.
Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by
// 'host_tensor'.
Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
xla::BorrowingLiteral* literal);
// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
// owned by 'host_tensors'.
Status HostTensorsToBorrowingLiteralTuple(
tensorflow::gtl::ArraySlice<Tensor> host_tensors,
xla::BorrowingLiteral* literal);
// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of
// type <target_type>.
// Fails if the literal's primitive type !=

View File

@ -225,7 +225,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// Computes the XLA shape for argument 'arg'.
Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
bool is_entry_computation,
xla::Shape* xla_shape) {
xla::Shape* xla_shape) const {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
@ -652,6 +652,7 @@ Status XlaCompiler::CompileSingleOp(
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
}
FixupSourceAndSinkEdges(graph.get());
return CompileGraph(options, name, std::move(graph), args, result);
}
@ -675,8 +676,8 @@ string ValidateFunctionDef(const FunctionDef* fdef,
return tensorflow::str_util::Join(invalid_ops, ", ");
}
// Check that the graph doesn't have any nodes incompatible with given
// device_type.
// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
// given device_type, invalid data type, missing attributes...)
Status ValidateGraph(const Graph* graph,
const FunctionLibraryDefinition& flib_def,
const DeviceType& device_type, const string& name) {
@ -694,6 +695,12 @@ Status ValidateGraph(const Graph* graph,
}
continue;
}
const OpDef* op_def;
if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) {
invalid_ops.push_back(node->def().op());
continue;
}
TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
invalid_ops.push_back(node->def().op());
}
@ -731,8 +738,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
graph.get(), local_flib_def_.get()));
// Detect ops incompatible with the device_type.
// FunctionalizeControlFlow may remove some unsupported ops.
// Detect invalid nodes.
// FunctionalizeControlFlow may remove some nodes from the graph.
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
options_.device_type, name));

View File

@ -314,7 +314,7 @@ class XlaCompiler {
// See the class comment for more details about the argument passing
// convention.
Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
xla::Shape* xla_shape);
xla::Shape* xla_shape) const;
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -1025,5 +1026,66 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
<< status.error_message();
}
// Tests a graph which has a node with invalid data type.
TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
NodeDef shape;
shape.set_name("Shape");
shape.set_op("Shape");
(*shape.mutable_attr())["T"].set_type(DT_INT32);
(*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */
Status status;
Node* shape_node = graph->AddNode(shape, &status);
TF_ASSERT_OK(status);
graph->AddControlEdge(graph->source_node(), shape_node);
std::vector<XlaCompiler::Argument> args;
XlaCompiler::CompilationResult result;
XlaCompiler compiler(DefaultOptions());
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.error_message(),
"is not in the list of allowed values"))
<< status.error_message();
}
TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
NodeDef no_op;
no_op.set_name("NoOp");
no_op.set_op("NoOp");
Status status;
graph->AddNode(no_op, &status);
TF_ASSERT_OK(status);
std::vector<XlaCompiler::Argument> args;
XlaCompiler compiler(DefaultOptions());
// No control edge linking NoOp with source/sink.
{
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args, &result);
ASSERT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.error_message(),
"The following nodes are unreachable "
"from the source in the graph: NoOp"))
<< status.error_message();
}
// Fix control edges for NoOp.
{
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args, &result));
EXPECT_EQ(0, result.resource_updates.size());
}
}
} // namespace
} // namespace tensorflow

View File

@ -19,11 +19,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@ -210,8 +212,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(dtype));
}
xla::Literal linspace_literal;
TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
xla::BorrowingLiteral linspace_literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
*iota = builder->ConstantLiteral(linspace_literal);
return Status::OK();
}
@ -245,8 +248,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(index_type));
}
xla::Literal linspace_literal;
TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
xla::BorrowingLiteral linspace_literal;
TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
// Broadcast the linspace constant across the indices along the new axis,
// and test equality at each position.

View File

@ -53,7 +53,6 @@ xla_proto_library(
deps = [
":xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:session_proto",
],
)
@ -499,37 +498,6 @@ cc_library(
],
)
cc_library(
name = "scanner",
srcs = ["scanner.cc"],
hdrs = ["scanner.h"],
visibility = [":internal"],
deps = [
":status",
":status_macros",
":types",
":util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "scanner_test",
srcs = ["scanner_test.cc"],
deps = [
":scanner",
":status",
":status_macros",
":test",
":types",
":util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "text_literal_reader",
srcs = ["text_literal_reader.cc"],

View File

@ -86,6 +86,7 @@ cc_library(
hdrs = ["executable_build_options.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
@ -109,6 +110,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:source_map_util",

View File

@ -153,8 +153,6 @@ class Client {
//
// If output_layout is non-null, then the output of the computation will be
// stored using that layout.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
StatusOr<std::unique_ptr<Literal>> ComputeConstant(
const XlaComputation& computation,
const Layout* output_layout = nullptr) const;

View File

@ -87,6 +87,18 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const {
return dump_optimized_hlo_proto_to_;
}
ExecutableBuildOptions&
ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to(
tensorflow::StringPiece dirpath) {
dump_unoptimized_hlo_proto_to_ = dirpath.ToString();
return *this;
}
const tensorflow::gtl::optional<string>&
ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const {
return dump_unoptimized_hlo_proto_to_;
}
ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to(
tensorflow::StringPiece dirpath) {
dump_per_pass_hlo_proto_to_ = dirpath.ToString();

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/optional.h"
@ -64,6 +65,13 @@ class ExecutableBuildOptions {
tensorflow::StringPiece dirpath);
const tensorflow::gtl::optional<string>& dump_optimized_hlo_proto_to() const;
// If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO
// protobuf to (as in DebugOptions).
ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to(
tensorflow::StringPiece dirpath);
const tensorflow::gtl::optional<string>& dump_unoptimized_hlo_proto_to()
const;
// If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs
// to (as in DebugOptions).
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to(
@ -76,6 +84,13 @@ class ExecutableBuildOptions {
ExecutableBuildOptions& set_hlo_profile(bool enabled);
tensorflow::gtl::optional<bool> hlo_profile() const;
void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) {
disabled_hlo_passes_.push_back(std::string(pass_name));
}
const tensorflow::gtl::ArraySlice<std::string> disabled_hlo_passes() const {
return disabled_hlo_passes_;
}
// Returns a string representation of the build options, suitable for
// debugging.
string ToString() const;
@ -87,8 +102,10 @@ class ExecutableBuildOptions {
bool result_layout_set_ = false;
tensorflow::gtl::optional<string> generate_hlo_graph_;
tensorflow::gtl::optional<string> dump_optimized_hlo_proto_to_;
tensorflow::gtl::optional<string> dump_unoptimized_hlo_proto_to_;
tensorflow::gtl::optional<string> dump_per_pass_hlo_proto_to_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
std::vector<std::string> disabled_hlo_passes_;
};
} // namespace xla

View File

@ -185,7 +185,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
run_options, backend_->StreamBorrower(),
backend_->eigen_intra_op_thread_pool());
if (executable_->dumping()) {
if (executable_->dumping_snapshot()) {
return ExecuteAndDump(&service_options, arguments);
}
return executable_->ExecuteOnStreamWrapper(
@ -195,36 +195,36 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
executable_->session_module()->set_execution_platform(
executable_->hlo_snapshot()->set_execution_platform(
backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module()));
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer result,
executable_->ExecuteOnStream(run_options, arguments,
/*hlo_execution_profile=*/nullptr));
TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module()));
TF_RETURN_IF_ERROR(executable_->DumpSessionModule());
TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot());
return std::move(result);
}
Status LocalExecutable::RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
SessionModule* session_module) {
session_module->clear_arguments();
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*argument));
*session_module->add_arguments() = literal->ToProto();
*hlo_snapshot->add_arguments() = literal->ToProto();
}
return Status::OK();
}
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
SessionModule* session_module) {
session_module->clear_result();
HloSnapshot* hlo_snapshot) {
hlo_snapshot->clear_result();
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*result));
*session_module->mutable_result() = literal->ToProto();
*hlo_snapshot->mutable_result() = literal->ToProto();
return Status::OK();
}
@ -304,6 +304,11 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
shaped_buffer);
}
StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
const GlobalDataHandle& data, int replica_number) {
return local_service_->GlobalDataToShapedBuffer(data, replica_number);
}
Status LocalClient::TransferToInfeedLocal(const Literal& literal,
int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -78,11 +79,10 @@ class LocalExecutable {
// proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
SessionModule* session_module);
HloSnapshot* hlo_snapshot);
// Records the result of the computation in a SessionModule proto.
Status RecordResult(const ShapedBuffer* result,
SessionModule* session_module);
Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
@ -136,6 +136,11 @@ class LocalClient : public Client {
StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer);
// Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
// as long as the handle is valid.
StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
const GlobalDataHandle& data, int replica_number);
// Transfer the given literal to the infeed queue of the given device.
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with

View File

@ -37,7 +37,6 @@ cc_library(
],
)
# TODO(b/74197823): Replace computation_builder with xla_builder.
cc_library(
name = "xla_builder",
srcs = ["xla_builder.cc"],

View File

@ -1613,13 +1613,35 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
auto b = CreateSubBuilder("sum");
b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
TF_ASSIGN_OR_RETURN(auto computation, b->Build());
return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{},
/*channel_id=*/tensorflow::gtl::nullopt);
});
}
XlaOp XlaBuilder::CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (!replica_group_ids.empty() || channel_id.has_value()) {
return Unimplemented(
"replica_group_ids and channel_id and is not supported in AllReduce");
}
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
{operand});
});

View File

@ -532,6 +532,29 @@ class XlaBuilder {
// supply one input to the sum and all replicas receive the resulting sum.
XlaOp CrossReplicaSum(const XlaOp& operand);
// Enqueues an operation that do an AllReduce of the operand cross cores. Here
// AllReduce means doing a reduction on the input operand cross cores and then
// broadcasting the reduction result to those cores. The reduction function is
// defined by `computation`, which should be a commutative computation on
// scalars, e.g., add, min, or max. The way that AllReduce is applied is
// configured by:
//
// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
// replicas belong to one group. Allreduce will be applied within subgroups.
// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
//
// - `channel_id`: for Allreduce nodes from different models, if they have the
// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
// applied cross models.
//
// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
XlaOp CrossReplicaSum(
const XlaOp& operand, const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
const tensorflow::gtl::optional<ChannelHandle>& channel_id =
tensorflow::gtl::nullopt);
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,

View File

@ -98,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
} // namespace
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) {
// Opaque and token types have empty layouts.
return Layout();
}
// A Layout proto corresponds to a single array, not a tuple.
DCHECK(!ShapeUtil::IsTuple(shape));
CHECK(ShapeUtil::IsArray(shape));
return CreateDefaultLayoutForRank(shape.dimensions_size());
}
@ -126,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
SetToDefaultLayout(&element_shape);
}
shape->clear_layout();
} else if (ShapeUtil::IsOpaque(*shape)) {
shape->clear_layout();
} else {
} else if (ShapeUtil::IsArray(*shape)) {
shape->mutable_layout()->set_format(DENSE);
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
minor_to_major->Resize(shape->dimensions_size(), 0);
SetDefaultLayoutToContainer(minor_to_major);
} else {
// Opaque, token types etc. have no layout.
shape->clear_layout();
}
}
@ -160,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
}
return Status::OK();
} else if (ShapeUtil::IsOpaque(shape)) {
if (shape.has_layout()) {
return InvalidArgument("opaque should not have a layout field");
}
return Status::OK();
} else {
// Array shape.
} else if (ShapeUtil::IsArray(shape)) {
if (!shape.has_layout()) {
return InvalidArgument("shape %s does not have a layout",
ShapeUtil::HumanString(shape).c_str());
}
return ValidateLayoutForShape(shape.layout(), shape);
} else {
// Token, opaque, etc. shape.
if (shape.has_layout()) {
return InvalidArgument(
"shape of primitive type %s should not have a layout",
PrimitiveType_Name(shape.element_type()).c_str());
}
return Status::OK();
}
}
@ -181,8 +189,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return InvalidArgument("a single Layout is not valid for tuple shapes");
}
if (ShapeUtil::IsOpaque(shape)) {
return Status::OK();
if (!ShapeUtil::IsArray(shape)) {
return InvalidArgument(
"shape of primitive type %s should not have a layout",
PrimitiveType_Name(shape.element_type()).c_str());
}
if (layout.format() == INVALID_FORMAT) {
@ -273,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) ||
if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
shape.layout().padded_dimensions_size() == 0) {
return false;
}
@ -323,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
// Tuple shape: all subshapes must have a layout.
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
[](const Shape& s) { return HasLayout(s); });
} else if (ShapeUtil::IsOpaque(shape)) {
} else if (!ShapeUtil::IsArray(shape)) {
// Opaque, token types etc. ignore layout.
return true;
}
return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
@ -432,12 +443,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
const Shape& rhs) {
if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) {
return false;
}
if (ShapeUtil::IsTuple(lhs)) {
if (ShapeUtil::TupleElementCount(lhs) !=
ShapeUtil::TupleElementCount(rhs)) {
if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) !=
ShapeUtil::TupleElementCount(rhs)) {
return false;
}
for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
@ -446,9 +454,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}
}
return true;
} else {
} else if (ShapeUtil::IsArray(lhs)) {
return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
LayoutUtil::Equal(lhs.layout(), rhs.layout());
} else {
// Layouts of non-array and non-tuple shapes is ignored.
return true;
}
}

View File

@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
"elements, but shape is rank"));
}
TEST_F(LayoutUtilTest, CopyTokenLayout) {
Shape src = ShapeUtil::MakeTokenShape();
Shape dst = ShapeUtil::MakeTokenShape();
// Layouts are trivially the same for token types and copying layouts should
// be a nop.
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
Shape src = ShapeUtil::MakeOpaqueShape();
Shape dst = ShapeUtil::MakeOpaqueShape();
// Layouts are trivially the same for opaque types and copying layouts should
// be a nop.
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
Shape src = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
Shape dst = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
}
TEST_F(LayoutUtilTest, ClearLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) {
EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
}
TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
// Opaque and token types trivially have layouts.
for (Shape shape :
{ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) {
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
LayoutUtil::ClearLayout(&shape);
EXPECT_TRUE(LayoutUtil::HasLayout(shape));
}
}
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),

View File

@ -317,15 +317,7 @@ class NearComparator {
rel_error = std::numeric_limits<float>::infinity();
} else {
abs_error = FpAbsoluteValue(actual - expected);
// If the expected result is exactly zero, don't compute relative error;
// that's meaningless.
//
// TODO(b/80321728): Come up with a better way to handle this case.
if (expected == NativeT{}) {
rel_error = 0;
} else {
rel_error = abs_error / FpAbsoluteValue(expected);
}
rel_error = abs_error / FpAbsoluteValue(expected);
}
const bool is_abs_mismatch = abs_error > error_.abs;
const bool is_rel_mismatch = rel_error > error_.rel;

View File

@ -987,6 +987,23 @@ std::unique_ptr<Literal> LiteralBase::Transpose(
return new_literal;
}
template <typename NativeT>
std::unique_ptr<Literal> LiteralBase::SliceInternal(
const Shape& result_shape,
tensorflow::gtl::ArraySlice<int64> start_indices) const {
auto result_literal = MakeUnique<Literal>(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
result_literal->EachCell<NativeT>(
[&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
NativeT value = Get<NativeT>(new_indices);
result_literal->Set<NativeT>(indices, value);
});
return result_literal;
}
std::unique_ptr<Literal> LiteralBase::Slice(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices) const {
@ -1004,51 +1021,17 @@ std::unique_ptr<Literal> LiteralBase::Slice(
const auto result_shape =
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
LayoutUtil::MinorToMajor(shape()));
auto result_literal = MakeUnique<Literal>(result_shape);
DimensionVector new_indices(ShapeUtil::Rank(result_shape));
switch (result_shape.element_type()) {
case F32:
result_literal->EachCell<float>(
[&](tensorflow::gtl::ArraySlice<int64> indices, float /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
float value = Get<float>(new_indices);
result_literal->Set<float>(indices, value);
});
return result_literal;
return SliceInternal<float>(result_shape, start_indices);
case BF16:
return SliceInternal<bfloat16>(result_shape, start_indices);
case C64:
result_literal->EachCell<complex64>(
[&](tensorflow::gtl::ArraySlice<int64> indices, complex64 /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
complex64 value = Get<complex64>(new_indices);
result_literal->Set<complex64>(indices, value);
});
return result_literal;
return SliceInternal<complex64>(result_shape, start_indices);
case S32:
result_literal->EachCell<int32>(
[&](tensorflow::gtl::ArraySlice<int64> indices, int32 /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
int32 value = Get<int32>(new_indices);
result_literal->Set<int32>(indices, value);
});
return result_literal;
return SliceInternal<int32>(result_shape, start_indices);
case U32:
result_literal->EachCell<uint32>(
[&](tensorflow::gtl::ArraySlice<int64> indices, uint32 /*value*/) {
for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
new_indices[i] = indices[i] + start_indices[i];
}
uint32 value = Get<uint32>(new_indices);
result_literal->Set<uint32>(indices, value);
});
return result_literal;
return SliceInternal<uint32>(result_shape, start_indices);
default:
LOG(FATAL) << "not yet implemented: "
<< PrimitiveType_Name(result_shape.element_type());
@ -2358,28 +2341,28 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal,
: LiteralBase(), root_piece_(&literal.piece(view_root)) {}
BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
: LiteralBase(), shape_(shape) {
CHECK(ShapeUtil::IsArray(shape_));
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
CHECK(ShapeUtil::IsArray(*shape_));
CHECK_NE(src_buf_ptr, nullptr);
CHECK(LayoutUtil::HasLayout(shape_));
CHECK(LayoutUtil::HasLayout(*shape_));
root_piece_ = Piece();
root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
root_piece_.set_subshape(&shape_);
root_piece_.set_subshape(shape_.get());
}
BorrowingLiteral::BorrowingLiteral(
tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
: LiteralBase(), shape_(shape) {
CHECK(ShapeUtil::IsTuple(shape_));
CHECK(!ShapeUtil::IsNestedTuple(shape_));
CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_));
: LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
CHECK(ShapeUtil::IsTuple(*shape_));
CHECK(!ShapeUtil::IsNestedTuple(*shape_));
CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
root_piece_ = Piece();
root_piece_.set_subshape(&shape_);
BuildPieceSubtree(shape_, &root_piece_);
root_piece_.set_subshape(shape_.get());
BuildPieceSubtree(*shape_, &root_piece_);
for (int i = 0; i < src_buf_ptrs.size(); ++i) {
const auto& src_shape = shape_.tuple_shapes(i);
const auto& src_shape = shape_->tuple_shapes(i);
CHECK(ShapeUtil::IsArray(src_shape));
root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
}

View File

@ -542,6 +542,12 @@ class LiteralBase {
friend class Literal;
friend class LiteralSlice;
friend class BorrowingLiteral;
private:
template <typename NativeT>
std::unique_ptr<Literal> SliceInternal(
const Shape& result_shape,
tensorflow::gtl::ArraySlice<int64> start_indices) const;
};
// Class representing literal values in XLA.
@ -1093,8 +1099,10 @@ class BorrowingLiteral : public LiteralBase {
const Piece& root_piece() const override { return root_piece_; };
Piece root_piece_;
// Shape of this literal.
const Shape shape_;
// Shape of this literal. Stored as unique_ptr so such that the (default)
// move construction of this class would be trivially correct: the pointer to
// Shape root_piece_ stores will still point to the correct address.
std::unique_ptr<Shape> shape_;
};
template <typename NativeT>

View File

@ -1431,7 +1431,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
std::vector<int64> int64_values = {1, 2, 3};
const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});
@ -1443,7 +1443,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
EXPECT_EQ(literal.Get<int64>({2}), 3);
}
TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) {
TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
std::vector<int64> one_two_three = {1, 2, 3};
const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});

View File

@ -12,6 +12,7 @@ py_library(
deps = [
":pywrap_xla",
"//tensorflow/compiler/xla:xla_data_proto_py",
"//tensorflow/compiler/xla/service:hlo_proto_py",
],
)
@ -53,6 +54,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",

View File

@ -276,6 +276,15 @@ const XlaComputation& LocalComputation::computation() const {
return computation_;
}
string LocalComputation::GetSerializedProto() const {
string result;
if (!computation_.proto().SerializeToString(&result)) {
LOG(ERROR) << "Failed to serialize the HloModuleProto.";
return "";
}
return result;
}
StatusOr<Shape> LocalComputation::GetReturnValueShape() const {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
computation_.GetProgramShape());
@ -589,10 +598,12 @@ _FORWARD_BINOP(Or)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)

View File

@ -112,6 +112,11 @@ class LocalComputation {
const XlaComputation& computation() const;
// Returns the HloModuleProto contained in the XlaComputation in the
// serialized binary format. Logs an internal error and returns an empty
// string on failure.
string GetSerializedProto() const;
// Returns the return-value shape for this computation.
StatusOr<Shape> GetReturnValueShape() const;
@ -300,10 +305,12 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)

View File

@ -851,6 +851,11 @@ tensorflow::ImportNumpy();
})) {
return nullptr;
}
if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) {
build_options.set_dump_unoptimized_hlo_proto_to(std::move(s));
})) {
return nullptr;
}
if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) {
build_options.set_dump_per_pass_hlo_proto_to(std::move(s));
})) {
@ -906,6 +911,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputation;
%unignore xla::swig::LocalComputation::Compile;
%unignore xla::swig::LocalComputation::GetReturnValueShape;
%unignore xla::swig::LocalComputation::GetSerializedProto;
%unignore xla::swig::LocalOp;
%unignore xla::swig::LocalComputationBuilder;
%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
@ -968,10 +974,12 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Expm1;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Log1p;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;

View File

@ -28,6 +28,7 @@ import numpy as np
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.python import pywrap_xla as c_api
from tensorflow.compiler.xla.service import hlo_pb2
# Most functions are snake_case for consistency with other modules, whereas
@ -88,10 +89,12 @@ _UNARY_OPS = [
'Not',
'Abs',
'Exp',
'Expm1',
'Floor',
'Round',
'Ceil',
'Log',
'Log1p',
'Sign',
'Cos',
'Sin',
@ -352,6 +355,7 @@ class CompileOptions(object):
def __init__(self):
self.generate_hlo_graph = None
self.dump_optimized_hlo_proto_to = None
self.dump_unoptimized_hlo_proto_to = None
self.dump_per_pass_hlo_proto_to = None
self.hlo_profile = False
@ -410,6 +414,17 @@ class LocalComputation(object):
assert isinstance(c_local_computation, c_api.LocalComputation)
self._delete = c_api.DeleteLocalComputation
def GetProto(self):
"""Get the HloModuleProto proto object in this local computation.
Returns:
An HloModuleProto proto object that has the whole-graph information.
"""
serialized = self.c_local_computation.GetSerializedProto()
proto = hlo_pb2.HloModuleProto.FromString(serialized)
return proto
def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
"""Compiles an un-compiled local computation.
@ -1100,6 +1115,61 @@ class ComputationBuilder(object):
dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
return dimension_numbers
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
lhs: LocalOp for the rank N+2 array of inputs.
rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of integer dilation factors.
rhs_dilation: length-N array-like of integer dilation factors.
dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a
triple (lhs_spec, rhs_spec, out_spec) where each element is a string of
length N+2 identifying by position (1) batch dimensions in lhs, rhs, and
the output with the character 'N', (2) feature dimensions in lhs and the
output with the character 'C', (3) input and output feature dimensions
in rhs with the characters 'I' and 'O' respectively, and (4) spatial
dimension correspondences between lhs, rhs, and the output using any
distinct characters. For example, to indicate dimension numbers
consistent with the Conv operation with two spatial dimensions, one
could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
dimension numbers consistent with the TensorFlow Conv2D operation, one
could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
convolution dimension specification, window strides are associated with
spatial dimension character labels according to the order in which the
labels appear in the rhs_spec string, so that window_strides[0] is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not 'I' or 'O'.
Returns: a LocalOp representing the ConvGenralDilated operation.
"""
if not isinstance(dimension_numbers,
xla_data_pb2.ConvolutionDimensionNumbers):
lhs_spec, rhs_spec, out_spec = dimension_numbers
dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()
dimension_numbers.input_batch_dimension = lhs_spec.index('N')
dimension_numbers.input_feature_dimension = lhs_spec.index('C')
dimension_numbers.output_batch_dimension = out_spec.index('N')
dimension_numbers.output_feature_dimension = out_spec.index('C')
dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')
dimension_numbers.kernel_spatial_dimensions.extend(
i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
dimension_numbers.input_spatial_dimensions.extend(
sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
key=lambda i: rhs_spec.index(lhs_spec[i])))
dimension_numbers.output_spatial_dimensions.extend(
sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
key=lambda i: rhs_spec.index(out_spec[i])))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
dimension_numbers)
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.

View File

@ -164,6 +164,16 @@ class ComputationsWithConstantsTest(LocalComputationTest):
c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]])
def testGetProto(self):
c = self._NewComputation()
c.Add(
c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])),
c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]])))
built = c.Build()
proto = built.GetProto() # HloModuleProto
self.assertTrue(len(proto.computations) == 1)
self.assertTrue(len(proto.computations[0].instructions) == 3)
def testSum2DF64(self):
c = self._NewComputation()
c.Add(
@ -509,6 +519,46 @@ class SingleOpTest(LocalComputationTest):
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvGeneralDilatedF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
dimension_numbers = ("NCHW", "OIHW", "NCHW")
c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation,
dimension_numbers)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=result)
def testConvGeneralDilatedPermutedF32(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
lhs = a(1, 1, 2, 3)
rhs = a(1, 1, 1, 2) * 10
strides = [1, 1]
pads = [(1, 0), (0, 1)]
lhs_dilation = (2, 1)
rhs_dilation = (1, 1)
dimension_numbers = ("NHWC", "OIHW", "CWNH")
c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))),
c.Constant(rhs),
strides, pads, lhs_dilation, rhs_dilation,
dimension_numbers)
result = np.array([[[[0., 0., 0.],
[10., 20., 0.],
[0., 0., 0.],
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
@ -521,6 +571,12 @@ class SingleOpTest(LocalComputationTest):
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
def testExpm1(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Expm1(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.expm1(arr))
def testRound(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
@ -533,6 +589,12 @@ class SingleOpTest(LocalComputationTest):
c.Log(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.log(arr))
def testLog1p(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Log1p(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.log1p(arr))
def testNeg(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])

View File

@ -42,7 +42,7 @@ tf_cc_binary(
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)
@ -61,7 +61,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)
@ -74,6 +74,6 @@ cc_library(
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

View File

@ -20,8 +20,8 @@ limitations under the License.
#include <memory>
#include <vector>
#include "grpc++/create_channel.h"
#include "grpc++/security/credentials.h"
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

View File

@ -32,19 +32,6 @@ namespace xla {
return tensorflow::ToGrpcStatus(s);
}
::grpc::Status GRPCService::Computation(::grpc::ServerContext* context,
const ComputationRequest* arg,
ComputationResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->Computation(arg, result); });
}
::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context,
const OpRequest* arg, OpResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->Op(arg, result); });
}
::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context,
const UnregisterRequest* arg,
UnregisterResponse* result) {
@ -60,21 +47,6 @@ namespace xla {
});
}
::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context,
const SetReturnValueRequest* arg,
SetReturnValueResponse* results) {
return DelegateRPC([this, arg, results]() {
return service_->SetReturnValue(arg, results);
});
}
::grpc::Status GRPCService::Execute(::grpc::ServerContext* context,
const ExecuteRequest* arg,
ExecuteResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->Execute(arg, result); });
}
::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
@ -82,13 +54,6 @@ namespace xla {
[this, arg, result]() { return service_->ExecuteGraph(arg, result); });
}
::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context,
const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->ExecuteAsync(arg, result); });
}
::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context,
const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) {
@ -136,20 +101,6 @@ namespace xla {
[this, arg, result]() { return service_->ResetDevice(arg, result); });
}
::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context,
const IsConstantRequest* arg,
IsConstantResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->IsConstant(arg, result); });
}
::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context,
const ComputeConstantRequest* arg,
ComputeConstantResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->ComputeConstant(arg, result); });
}
::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context,
const GetShapeRequest* arg,
GetShapeResponse* result) {
@ -157,43 +108,4 @@ namespace xla {
[this, arg, result]() { return service_->GetShape(arg, result); });
}
::grpc::Status GRPCService::GetComputationShape(
::grpc::ServerContext* context, const GetComputationShapeRequest* arg,
GetComputationShapeResponse* result) {
return DelegateRPC([this, arg, result]() {
return service_->GetComputationShape(arg, result);
});
}
::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context,
const GetLocalShapeRequest* arg,
GetLocalShapeResponse* result) {
return DelegateRPC(
[this, arg, result]() { return service_->GetLocalShape(arg, result); });
}
::grpc::Status GRPCService::GetComputationStats(
::grpc::ServerContext* context, const ComputationStatsRequest* arg,
ComputationStatsResponse* result) {
return DelegateRPC([this, arg, result]() {
return service_->GetComputationStats(arg, result);
});
}
::grpc::Status GRPCService::SnapshotComputation(
::grpc::ServerContext* context, const SnapshotComputationRequest* arg,
SnapshotComputationResponse* result) {
return DelegateRPC([this, arg, result]() {
return service_->SnapshotComputation(arg, result);
});
}
::grpc::Status GRPCService::LoadComputationSnapshot(
::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg,
LoadComputationSnapshotResponse* result) {
return DelegateRPC([this, arg, result]() {
return service_->LoadComputationSnapshot(arg, result);
});
}
} // namespace xla

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_
#include "grpc++/server_context.h"
#include "grpcpp/server_context.h"
#include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service {
static StatusOr<std::unique_ptr<GRPCService>> NewService(
se::Platform* platform = nullptr);
::grpc::Status Computation(::grpc::ServerContext* context,
const ComputationRequest* arg,
ComputationResponse* result) override;
::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg,
OpResponse* result) override;
::grpc::Status Unregister(::grpc::ServerContext* context,
const UnregisterRequest* arg,
UnregisterResponse* result) override;
@ -46,22 +39,10 @@ class GRPCService : public grpc::XlaService::Service {
const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;
::grpc::Status SetReturnValue(::grpc::ServerContext* context,
const SetReturnValueRequest* arg,
SetReturnValueResponse* results) override;
::grpc::Status Execute(::grpc::ServerContext* context,
const ExecuteRequest* arg,
ExecuteResponse* result) override;
::grpc::Status ExecuteGraph(::grpc::ServerContext* context,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
::grpc::Status ExecuteAsync(::grpc::ServerContext* context,
const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) override;
::grpc::Status WaitForExecution(::grpc::ServerContext* context,
const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) override;
@ -86,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service {
const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;
::grpc::Status IsConstant(::grpc::ServerContext* context,
const IsConstantRequest* arg,
IsConstantResponse* result) override;
::grpc::Status ComputeConstant(::grpc::ServerContext* context,
const ComputeConstantRequest* arg,
ComputeConstantResponse* result) override;
::grpc::Status GetShape(::grpc::ServerContext* context,
const GetShapeRequest* arg,
GetShapeResponse* result) override;
::grpc::Status GetComputationShape(
::grpc::ServerContext* context, const GetComputationShapeRequest* arg,
GetComputationShapeResponse* result) override;
::grpc::Status GetLocalShape(::grpc::ServerContext* context,
const GetLocalShapeRequest* arg,
GetLocalShapeResponse* result) override;
::grpc::Status GetComputationStats(::grpc::ServerContext* context,
const ComputationStatsRequest* arg,
ComputationStatsResponse* result) override;
::grpc::Status SnapshotComputation(
::grpc::ServerContext* context, const SnapshotComputationRequest* arg,
SnapshotComputationResponse* result) override;
::grpc::Status LoadComputationSnapshot(
::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg,
LoadComputationSnapshotResponse* result) override;
private:
std::unique_ptr<::xla::Service> service_;

View File

@ -15,9 +15,9 @@ limitations under the License.
// Basic server binary that exposes a xla::Service through a GRPC interface
// on a configurable port.
#include "grpc++/security/server_credentials.h"
#include "grpc++/server.h"
#include "grpc++/server_builder.h"
#include "grpcpp/security/server_credentials.h"
#include "grpcpp/server.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"

View File

@ -62,21 +62,6 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
});
}
Status GRPCStub::LoadComputationSnapshot(
const LoadComputationSnapshotRequest* request,
LoadComputationSnapshotResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->LoadComputationSnapshot(context, *request, response);
});
}
Status GRPCStub::Execute(const ExecuteRequest* request,
ExecuteResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->Execute(context, *request, response);
});
}
Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@ -84,13 +69,6 @@ Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
});
}
Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request,
ExecuteParallelResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ExecuteParallel(context, *request, response);
});
}
Status GRPCStub::ExecuteGraphParallel(
const ExecuteGraphParallelRequest* request,
ExecuteParallelResponse* response) {
@ -99,13 +77,6 @@ Status GRPCStub::ExecuteGraphParallel(
});
}
Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request,
ExecuteAsyncResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ExecuteAsync(context, *request, response);
});
}
Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request,
WaitForExecutionResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@ -120,13 +91,6 @@ Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request,
});
}
Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request,
ComputationStatsResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetComputationStats(context, *request, response);
});
}
Status GRPCStub::GetComputationGraphStats(
const ComputationGraphStatsRequest* request,
ComputationStatsResponse* response) {
@ -135,13 +99,6 @@ Status GRPCStub::GetComputationGraphStats(
});
}
Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request,
GetComputationShapeResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetComputationShape(context, *request, response);
});
}
Status GRPCStub::GetShape(const GetShapeRequest* request,
GetShapeResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@ -163,48 +120,6 @@ Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request,
});
}
// Methods used by ComputationBuilder.
Status GRPCStub::Computation(const ComputationRequest* request,
ComputationResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->Computation(context, *request, response);
});
}
Status GRPCStub::Op(const OpRequest* request, OpResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->CreateOp(context, *request, response);
});
}
Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request,
GetLocalShapeResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->GetLocalShape(context, *request, response);
});
}
Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request,
SetReturnValueResponse* responses) {
return MakeRPC([this, request, responses](::grpc::ClientContext* context) {
return grpc_stub_->SetReturnValue(context, *request, responses);
});
}
Status GRPCStub::IsConstant(const IsConstantRequest* request,
IsConstantResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->IsConstant(context, *request, response);
});
}
Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request,
ComputeConstantResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->ComputeConstant(context, *request, response);
});
}
Status GRPCStub::ComputeConstantGraph(
const ComputeConstantGraphRequest* request,
ComputeConstantResponse* response) {
@ -213,14 +128,6 @@ Status GRPCStub::ComputeConstantGraph(
});
}
// Methods used by Computation.
Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request,
SnapshotComputationResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
return grpc_stub_->SnapshotComputation(context, *request, response);
});
}
// Methods used by GlobalData.
Status GRPCStub::Unregister(const UnregisterRequest* request,
UnregisterResponse* response) {

View File

@ -43,39 +43,21 @@ class GRPCStub : public ServiceInterface {
Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;
Status LoadComputationSnapshot(
const LoadComputationSnapshotRequest* request,
LoadComputationSnapshotResponse* result) override;
Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;
Status ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) override;
Status ExecuteParallel(const ExecuteParallelRequest* arg,
ExecuteParallelResponse* result) override;
Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request,
ExecuteParallelResponse* response) override;
Status ExecuteAsync(const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) override;
Status WaitForExecution(const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) override;
Status DeconstructTuple(const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;
Status GetComputationStats(const ComputationStatsRequest* arg,
ComputationStatsResponse* result) override;
Status GetComputationGraphStats(const ComputationGraphStatsRequest* request,
ComputationStatsResponse* response) override;
Status GetComputationShape(const GetComputationShapeRequest* arg,
GetComputationShapeResponse* result) override;
Status GetShape(const GetShapeRequest* arg,
GetShapeResponse* result) override;
@ -85,30 +67,9 @@ class GRPCStub : public ServiceInterface {
Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
CreateChannelHandleResponse* result) override;
// Methods used by ComputationBuilder.
Status Computation(const ComputationRequest* arg,
ComputationResponse* result) override;
Status Op(const OpRequest* arg, OpResponse* result) override;
Status GetLocalShape(const GetLocalShapeRequest* arg,
GetLocalShapeResponse* result) override;
Status SetReturnValue(const SetReturnValueRequest* arg,
SetReturnValueResponse* results) override;
Status IsConstant(const IsConstantRequest* arg,
IsConstantResponse* result) override;
Status ComputeConstant(const ComputeConstantRequest* arg,
ComputeConstantResponse* result) override;
Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
ComputeConstantResponse* result) override;
// Methods used by Computation.
Status SnapshotComputation(const SnapshotComputationRequest* ag,
SnapshotComputationResponse* result) override;
// Methods used by GlobalData.
Status Unregister(const UnregisterRequest* arg,
UnregisterResponse* result) override;

View File

@ -75,19 +75,7 @@ service XlaService {
rpc GetShape(GetShapeRequest) returns (GetShapeResponse) {
}
// Requests the program shape of the referenced computation.
rpc GetComputationShape(GetComputationShapeRequest)
returns (GetComputationShapeResponse) {
}
// Requests the statistics of the given computation.
rpc GetComputationStats(ComputationStatsRequest)
returns (ComputationStatsResponse) {
}
// Requests the statistics of the given computation.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
rpc GetComputationGraphStats(ComputationGraphStatsRequest)
returns (ComputationStatsResponse) {
}
@ -121,25 +109,12 @@ service XlaService {
rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) {
}
// Tests if an expression is a compile-time constant.
rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) {
}
// Computes the value of a constant expression.
rpc ComputeConstant(ComputeConstantRequest)
returns (ComputeConstantResponse) {
}
// Computes the value of a constant expression. The request contains the
// computation graph for the constant expression.
rpc ComputeConstantGraph(ComputeConstantGraphRequest)
returns (ComputeConstantResponse) {
}
// Retrieves the inferred shape for a value within a computation.
rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) {
}
// Requests one or more device handles from the target. The returned device
// handles can be used to specify the device on which to execute computations
// or transfer data.
@ -153,32 +128,6 @@ service XlaService {
returns (CreateChannelHandleResponse) {
}
// Requests that the referenced computation be specialized for the provided
// arguments for subsequent execution. This permits things such as value
// specialization.
rpc Specialize(SpecializeRequest) returns (SpecializeResponse) {
}
// Modifies the provided computation so that subsequent executions
// will compute the provided ComputationDataHandle, rather than the
// last expression enqueued on that Computation.
rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) {
}
// Computation creates a new computation with the given name.
// A unique ComputationHandle is returned.
rpc Computation(ComputationRequest) returns (ComputationResponse) {
}
// Adds a new op to a computation.
rpc CreateOp(OpRequest) returns (OpResponse) {
}
// Invokes the provided computation with the provided global data passed as
// immutable arguments. Returns global data output and execution timing.
rpc Execute(ExecuteRequest) returns (ExecuteResponse) {
}
// Invokes the provided computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
@ -188,38 +137,13 @@ service XlaService {
// Invokes the provided list of computations in parallel with the provided
// global data for each computation. Returns a list of global data output and
// execution timing.
rpc ExecuteParallel(ExecuteParallelRequest)
returns (ExecuteParallelResponse) {
}
// Invokes the provided list of computations in parallel with the provided
// global data for each computation. Returns a list of global data output and
// execution timing.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
rpc ExecuteGraphParallel(ExecuteGraphParallelRequest)
returns (ExecuteParallelResponse) {
}
// Invokes the provided computation with the provided global data passed as
// immutable arguments. Returns a handle to the execution.
rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) {
}
// Waits until the given execution (aysnchronously launched) is complete, and
// returns the global data output.
rpc WaitForExecution(WaitForExecutionRequest)
returns (WaitForExecutionResponse) {
}
// Serializes a computation to proto form, so it can be loaded via
// LoadComputationSnapshot.
rpc SnapshotComputation(SnapshotComputationRequest)
returns (SnapshotComputationResponse) {
}
// Loads a computation from a captured snapshot.
rpc LoadComputationSnapshot(LoadComputationSnapshotRequest)
returns (LoadComputationSnapshotResponse) {
}
}

View File

@ -1,197 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/scanner.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
// Returns true if c can be the first character in an identifier.
bool IsIdentifierFirst(int c) { return std::isalpha(c) || c == '_'; }
// Returns true if c can be the non-first character in an identifier.
bool IsIdentifierLater(int c) { return std::isalnum(c) || c == '_'; }
// Returns true if str is an identifier.
bool IsIdentifier(tensorflow::StringPiece str) {
if (str.empty() || !IsIdentifierFirst(str[0])) {
return false;
}
for (int64 i = 1; i < str.size(); ++i) {
if (!IsIdentifierLater(str[i])) {
return false;
}
}
return true;
}
} // namespace
Scanner::Scanner(tensorflow::StringPiece input) : input_(input), position_(0) {}
bool Scanner::ok() const { return status().ok(); }
const Status& Scanner::status() const { return status_; }
bool Scanner::Match(tensorflow::StringPiece match) {
SkipWhitespace();
if (ok() && position_ + match.size() <= input_.size() &&
std::equal(match.begin(), match.end(), input_.begin() + position_)) {
SkipChars(match.size());
VLOG(10) << "Matched \"" << match << "\"";
return true;
} else {
return false;
}
}
void Scanner::Expect(tensorflow::StringPiece expect) {
if (!Match(expect)) {
SetError(tensorflow::strings::StrCat("Expected \"", expect, "\"."));
}
}
bool Scanner::MatchReadIdentifier(string* identifier) {
SkipWhitespace();
if (!IsIdentifierFirst(PeekChar())) {
return false;
}
identifier->clear();
do {
*identifier += ReadChar();
} while (IsIdentifierLater(PeekChar()));
VLOG(10) << "Read identifier " << identifier;
CHECK(IsIdentifier(*identifier));
return true;
}
string Scanner::ReadIdentifier() {
string identifier;
if (!MatchReadIdentifier(&identifier)) {
SetError("Expected identifier.");
}
return identifier;
}
void Scanner::ExpectIdentifier(tensorflow::StringPiece expect) {
CHECK(IsIdentifier(expect));
string identifier;
if (!MatchReadIdentifier(&identifier)) {
SetError(tensorflow::strings::StrCat("Expected identifier ", expect, "."));
}
if (identifier != expect) {
SetError(tensorflow::strings::StrCat("Expected identifier ", expect,
", but got ", identifier, "."));
}
}
// Matches the end of the input, also known as End Of File (EOF).
bool Scanner::MatchEof() {
SkipWhitespace();
return PeekChar() == EOF;
}
void Scanner::ExpectEof() {
if (!MatchEof()) {
SetError("Expected end of input.");
}
}
// Reads a vector of the format "(1, 2, 3)".
std::vector<int64> Scanner::ReadIntVector() {
std::vector<int64> ints;
Expect("(");
if (!Match(")") && ok()) {
ints.push_back(ReadInt());
while (Match(",")) {
ints.push_back(ReadInt());
}
Expect(")");
}
VLOG(10) << "Read int vector with " << ints.size() << " elements.";
return ints;
}
int64 Scanner::ReadInt() {
bool negative = Match("-");
if (!PeekDigit()) {
SetError("Expected integer.");
return 0;
}
int64 integer = 0;
do {
integer = (ReadChar() - '0') + integer * 10;
} while (PeekDigit());
integer = negative ? -integer : integer;
VLOG(10) << "Read integer " << integer;
return integer;
}
void Scanner::SkipWhitespace() {
while (PeekWhitespace()) {
SkipChars(1);
}
}
int Scanner::ReadChar() {
int c = PeekChar();
SkipChars(1);
VLOG(20) << "Read char " << c;
return c;
}
int Scanner::PeekChar() const {
return ok() && position_ < input_.size() ? input_[position_] : EOF;
}
bool Scanner::PeekDigit() const {
// Do not use std::isdigit since it depends on the locale and we do not
// handle any digits beyond 0-9.
const char c = PeekChar();
return '0' <= c && c <= '9';
}
bool Scanner::PeekAlnum() const { return std::isalnum(PeekChar()); }
bool Scanner::PeekWhitespace() const { return std::isspace(PeekChar()); }
void Scanner::SkipChars(int64 count) {
CHECK_GE(count, 0);
position_ += count;
}
void Scanner::SetError(string error_message) {
// Only the first error is recorded since any later errors will likely be a
// consequence of the first error.
if (ok()) {
status_ = InvalidArgumentStrCat(std::move(error_message));
position_ = input_.size();
VLOG(10) << "Failed scanner with error " << status_.ToString();
} else {
VLOG(10) << "Error on already failed scanner is " << error_message;
}
}
} // namespace xla

View File

@ -1,102 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SCANNER_H_
#define TENSORFLOW_COMPILER_XLA_SCANNER_H_
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
namespace xla {
// Simple class for parsing data. The concepts for the interface are:
//
// Match(x): Returns true if x is next in the input and in that case skips
// past it. Otherwise returns false.
//
// Expect(x): As Match(x), but requires x to be next in the input.
//
// MatchReadX(x): Returns true if an X is next in the input and in that case
// skips past it and assigns it to x. Otherwise returns false.
//
// ReadX(): As ReadMatchX(), but requires an X to be next in the input and
// returns it.
//
// PeekX(): Returns true if an X is next in the input and does not skip
// past it either way.
//
// All of these, except those that work on individual characters, skip
// whitespace.
//
// If a requirement is not met, the error is available in status(). A Scanner
// with a failed status() will behave as though the rest of the input is EOF and
// will not record further errors after that point.
class Scanner {
public:
Scanner(tensorflow::StringPiece input);
bool ok() const;
const Status& status() const;
bool Match(tensorflow::StringPiece match);
void Expect(tensorflow::StringPiece expect);
// Match-reads an identifier. An identifier starts with an alphabetic
// character or an underscore followed by any number of characters that are
// each alphanumeric or underscore.
bool MatchReadIdentifier(string* identifier);
string ReadIdentifier();
void ExpectIdentifier(tensorflow::StringPiece expect);
// Matches the end of the input, also known as End Of File (EOF).
bool MatchEof();
void ExpectEof();
// Reads a vector of the format "(1, 4, 5)".
std::vector<int64> ReadIntVector();
// Reads an integer. Can start with a minus but not a plus.
int64 ReadInt();
// Keeps skipping until encountering a non-whitespace character.
void SkipWhitespace();
// *** Below here are character-level methods that do not skip whitespace.
int ReadChar();
int PeekChar() const;
bool PeekDigit() const;
bool PeekAlnum() const;
bool PeekWhitespace() const;
// Skip past the next count characters.
void SkipChars(int64 count);
private:
// Sets a failed status. The input is in effect replaced with EOF after
// this. Only the first error is recorded.
void SetError(string error_message);
const tensorflow::StringPiece input_;
int64 position_;
Status status_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SCANNER_H_

View File

@ -1,124 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// TODO(b/80179519): Fix open source build for real.
#if 0
#include "tensorflow/compiler/xla/scanner.h"
#include <string>
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
namespace xla {
namespace {
TEST(Scanner, Empty) {
Scanner scanner("");
EXPECT_EQ(scanner.PeekChar(), EOF);
EXPECT_TRUE(scanner.MatchEof());
EXPECT_TRUE(scanner.Match(""));
EXPECT_FALSE(scanner.Match("1"));
EXPECT_TRUE(scanner.ok());
}
TEST(Scanner, Prefix) {
Scanner scanner("1234 5");
EXPECT_FALSE(scanner.MatchEof());
EXPECT_TRUE(scanner.Match("12"));
EXPECT_TRUE(scanner.Match("34 "));
EXPECT_FALSE(scanner.MatchEof());
EXPECT_FALSE(scanner.Match("5 "));
EXPECT_TRUE(scanner.Match("5"));
EXPECT_TRUE(scanner.MatchEof());
}
TEST(Scanner, Whitespace) {
Scanner scanner(" \t\n\r 1\t2\n\n");
EXPECT_FALSE(scanner.Match(" "));
EXPECT_TRUE(scanner.Match("1"));
EXPECT_TRUE(scanner.Match("2"));
EXPECT_TRUE(scanner.MatchEof());
EXPECT_TRUE(scanner.ok());
}
TEST(Scanner, Fail) {
Scanner scanner("153 4q");
scanner.Expect("5");
EXPECT_FALSE(scanner.ok());
EXPECT_FALSE(scanner.status().ok());
EXPECT_TRUE(scanner.MatchEof());
}
TEST(Scanner, Identifier) {
Scanner scanner("1 q1 _1_ _1a= qqb");
string identifier = "foo";
EXPECT_FALSE(scanner.MatchReadIdentifier(&identifier));
EXPECT_EQ(identifier, "foo");
scanner.Match("1");
EXPECT_TRUE(scanner.MatchReadIdentifier(&identifier));
EXPECT_EQ(identifier, "q1");
scanner.ExpectIdentifier("_1_");
EXPECT_TRUE(scanner.ok());
scanner.ExpectIdentifier("_1a");
EXPECT_TRUE(scanner.ok());
// The = after _1a is not included in the identifier.
scanner.Expect("=");
// The expected identifier matches a prefix but is not the full identifier in
// the input.
EXPECT_TRUE(scanner.ok());
scanner.ExpectIdentifier("qq");
EXPECT_FALSE(scanner.ok());
}
TEST(Scanner, Int) {
Scanner scanner("1_2 3% -1 124345 -363 0 -0");
EXPECT_EQ(1, scanner.ReadInt());
EXPECT_TRUE(scanner.Match("_"));
EXPECT_EQ(2, scanner.ReadInt());
EXPECT_EQ(3, scanner.ReadInt());
EXPECT_TRUE(scanner.Match("%"));
EXPECT_EQ(-1, scanner.ReadInt());
EXPECT_EQ(124345, scanner.ReadInt());
EXPECT_EQ(-363, scanner.ReadInt());
EXPECT_EQ(0, scanner.ReadInt());
EXPECT_EQ(0, scanner.ReadInt());
EXPECT_TRUE(scanner.MatchEof());
}
TEST(Scanner, IntVector) {
Scanner scanner("()(0) (-1,2) ( 3 , 4 )");
EXPECT_THAT(scanner.ReadIntVector(), testing::IsEmpty());
EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(0));
EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(-1, 2));
EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(3, 4));
EXPECT_TRUE(scanner.MatchEof());
EXPECT_TRUE(scanner.ok());
}
} // namespace
} // namespace xla
#endif

View File

@ -16,12 +16,9 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
xla_proto_library(
name = "session_proto",
srcs = ["session.proto"],
visibility = ["//visibility:public"],
deps = ["//tensorflow/compiler/xla:xla_data_proto"],
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library_py",
)
xla_proto_library(
@ -31,6 +28,12 @@ xla_proto_library(
deps = ["//tensorflow/compiler/xla:xla_data_proto"],
)
tf_proto_library_py(
name = "hlo_proto", # bzl adds a _py suffix only to the OSS target.
srcs = ["hlo.proto"],
visibility = ["//visibility:public"],
)
xla_proto_library(
name = "hlo_profile_printer_data",
srcs = ["hlo_profile_printer_data.proto"],
@ -266,6 +269,7 @@ cc_library(
"dfs_hlo_visitor.cc",
"hlo_computation.cc",
"hlo_instruction.cc",
"hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
"hlo_sharding.cc",
@ -273,18 +277,21 @@ cc_library(
hdrs = [
"dfs_hlo_visitor.h",
"dfs_hlo_visitor_with_default.h",
"hlo_clone_context.h",
"hlo_computation.h",
"hlo_domain_metadata.h",
"hlo_instruction.h",
"hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
"hlo_sharding.h",
],
deps = [
":hlo_casting_utils",
":hlo_module_config",
":hlo_proto",
":hlo_reachability",
":name_uniquer",
":versioned_computation_handle",
"//tensorflow/compiler/xla:array",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
@ -297,6 +304,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:human_readable_json",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
@ -336,8 +344,8 @@ tf_cc_test(
":hlo",
":pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@ -375,8 +383,8 @@ cc_library(
deps = [
":hlo",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@ -386,33 +394,25 @@ tf_cc_test(
srcs = ["hlo_matchers_test.cc"],
deps = [
":hlo_matchers",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "versioned_computation_handle",
srcs = ["versioned_computation_handle.cc"],
hdrs = ["versioned_computation_handle.h"],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "hlo_instruction_test",
srcs = ["hlo_instruction_test.cc"],
deps = [
":hlo",
":hlo_parser",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@ -429,9 +429,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -534,45 +534,6 @@ tf_cc_test(
],
)
cc_library(
name = "user_computation",
srcs = ["user_computation.cc"],
hdrs = ["user_computation.h"],
deps = [
":hlo",
":session_proto",
":shape_inference",
":versioned_computation_handle",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "user_computation_test",
srcs = ["user_computation_test.cc"],
deps = [
":hlo_matchers",
":user_computation",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library(
name = "platform_util",
srcs = ["platform_util.cc"],
@ -618,10 +579,8 @@ cc_library(
":allocation_tracker",
":backend",
":channel_tracker",
":compilation_cache",
":compiler",
":computation_layout",
":computation_tracker",
":device_memory_allocator",
":executable",
":execution_tracker",
@ -632,11 +591,8 @@ cc_library(
":hlo_module_config",
":hlo_proto_util",
":platform_util",
":session_proto",
":source_map_util",
":transfer_manager",
":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:service_interface",
@ -663,7 +619,6 @@ cc_library(
":backend",
":compiler",
":computation_layout",
":computation_tracker",
":device_memory_allocator",
":executable",
":hlo",
@ -672,8 +627,6 @@ cc_library(
":platform_util",
":service",
":shaped_buffer",
":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
@ -697,7 +650,6 @@ cc_library(
":backend",
":compiler",
":computation_layout",
":computation_tracker",
":platform_util",
":service",
"//tensorflow/compiler/xla:status_macros",
@ -794,9 +746,7 @@ cc_library(
":hlo_graph_dumper",
":hlo_proto",
":pool",
":session_proto",
":shaped_buffer",
":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@ -892,34 +842,12 @@ cc_library(
],
)
cc_library(
name = "computation_tracker",
srcs = ["computation_tracker.cc"],
hdrs = ["computation_tracker.h"],
deps = [
":hlo",
":hlo_module_config",
":session_proto",
":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
cc_library(
name = "channel_tracker",
srcs = ["channel_tracker.cc"],
hdrs = ["channel_tracker.h"],
deps = [
":hlo",
":session_proto",
":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -1025,7 +953,6 @@ tf_cc_test(
":buffer_assignment",
":buffer_value",
":call_graph",
":computation_tracker",
":copy_insertion",
":cpu_plugin",
":flatten_call_graph",
@ -1039,9 +966,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@ -1077,9 +1004,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -1180,9 +1107,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -1215,9 +1142,22 @@ tf_cc_test(
deps = [
":hlo_matchers",
":instruction_fusion",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
cc_library(
name = "multi_output_fusion",
srcs = ["multi_output_fusion.cc"],
hdrs = ["multi_output_fusion.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
],
)
@ -1389,9 +1329,9 @@ tf_cc_test(
deps = [
":gather_expander",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:test_macros_header",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -1697,14 +1637,11 @@ tf_cc_test(
name = "hlo_cost_analysis_test",
srcs = ["hlo_cost_analysis_test.cc"],
deps = [
":computation_tracker",
":cpu_plugin",
":hlo",
":hlo_cost_analysis",
":local_service",
":service",
":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
@ -1743,9 +1680,9 @@ tf_cc_test(
":cpu_plugin",
":hlo_cost_analysis",
":hlo_execution_profile",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@ -1926,9 +1863,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@ -2045,20 +1982,6 @@ tf_cc_test(
],
)
cc_library(
name = "compilation_cache",
srcs = ["compilation_cache.cc"],
hdrs = ["compilation_cache.h"],
deps = [
":executable",
":hlo_module_config",
":versioned_computation_handle",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
cc_library(
name = "layout_assignment",
srcs = [
@ -2229,6 +2152,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
@ -2263,11 +2187,11 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@ -2289,9 +2213,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@ -2339,6 +2263,7 @@ cc_library(
hdrs = ["hlo_cse.h"],
deps = [
":hlo",
":hlo_domain_map",
":hlo_pass",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@ -2361,10 +2286,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@ -2403,6 +2328,78 @@ tf_cc_test(
],
)
cc_library(
name = "hlo_domain_map",
srcs = ["hlo_domain_map.cc"],
hdrs = ["hlo_domain_map.h"],
deps = [
":hlo",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)
cc_library(
name = "hlo_sharding_metadata",
srcs = ["hlo_sharding_metadata.cc"],
hdrs = [
"hlo_sharding_metadata.h",
],
deps = [
":hlo",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
],
)
cc_library(
name = "hlo_domain_isolator",
srcs = ["hlo_domain_isolator.cc"],
hdrs = ["hlo_domain_isolator.h"],
deps = [
":hlo",
":hlo_graph_dumper",
":hlo_pass",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
],
)
cc_library(
name = "hlo_domain_remover",
srcs = ["hlo_domain_remover.cc"],
hdrs = ["hlo_domain_remover.h"],
deps = [
":hlo",
":hlo_domain_isolator",
":hlo_domain_map",
":hlo_graph_dumper",
":hlo_pass",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "hlo_domain_test",
srcs = ["hlo_domain_test.cc"],
deps = [
":hlo",
":hlo_domain_isolator",
":hlo_domain_remover",
":hlo_parser",
":hlo_sharding_metadata",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library(
name = "hlo_element_type_converter",
srcs = ["hlo_element_type_converter.cc"],
@ -2484,10 +2481,10 @@ xla_test(
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -2576,6 +2573,7 @@ cc_library(
hdrs = ["hlo_graph_dumper.h"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_execution_profile",
":hlo_tfgraph_builder",
"//tensorflow/compiler/xla:literal_util",
@ -2633,10 +2631,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@ -2773,7 +2771,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@ -2809,8 +2807,8 @@ tf_cc_test(
":tuple_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -2833,9 +2831,10 @@ tf_cc_test(
deps = [
":while_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -2861,8 +2860,8 @@ tf_cc_test(
":hlo_matchers",
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@ -2888,8 +2887,8 @@ tf_cc_test(
":hlo_matchers",
":while_loop_constant_sinking",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@ -2942,9 +2941,75 @@ tf_cc_test(
":hlo_matchers",
":indexed_array_analysis",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
cc_library(
name = "hlo_parser",
srcs = ["hlo_parser.cc"],
hdrs = ["hlo_parser.h"],
deps = [
":hlo",
":hlo_lexer",
":hlo_sharding_metadata",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "hlo_parser_test",
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
":hlo_parser",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main", # fixdeps: keep
],
)
cc_library(
name = "hlo_lexer",
srcs = ["hlo_lexer.cc"],
hdrs = [
"hlo_lexer.h",
"hlo_token.h",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
],
)
cc_library(
name = "hlo_casting_utils",
hdrs = ["hlo_casting_utils.h"],
deps = ["//tensorflow/core:lib"],
)
tf_cc_test(
name = "hlo_casting_utils_test",
srcs = ["hlo_casting_utils_test.cc"],
deps = [
":hlo",
":hlo_casting_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
],
)

View File

@ -233,10 +233,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* operand, HloInstruction* max,
HloInstruction* max_operand);
// A Reshape or Broadcast that feeds an element-wise operation with a unique
// non-scalar operand can sink to after the operation.
StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* reshape_or_broadcast);
// A Broadcast that feeds an element-wise operation with a unique non-scalar
// operand can sink to after the operation.
StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* broadcast);
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
@ -1305,7 +1305,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
// broadcast after the unary element-wise operation.
TF_ASSIGN_OR_RETURN(
bool sink_succeeded,
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
changed_ |= sink_succeeded;
if (sink_succeeded) {
return Status::OK();
@ -1557,15 +1557,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
return Status::OK();
}
StatusOr<bool> AlgebraicSimplifierVisitor::
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* reshape_or_broadcast) {
StatusOr<bool>
AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* broadcast) {
TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
bool changed = false;
if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) {
if (ShapeUtil::IsScalar(broadcast->shape())) {
return false;
}
HloInstruction* operand = reshape_or_broadcast->mutable_operand(0);
for (HloInstruction* user : reshape_or_broadcast->users()) {
HloInstruction* operand = broadcast->mutable_operand(0);
for (HloInstruction* user : broadcast->users()) {
if (user->user_count() == 0 && user != computation_->root_instruction()) {
continue;
}
@ -1583,55 +1584,50 @@ StatusOr<bool> AlgebraicSimplifierVisitor::
continue;
}
int64 reshape_or_broadcast_operand_index = -1;
// Find the unique non-scalar operand or continue if there isn't one.
int64 scalar_count = 0;
for (int64 i = 0; i < user->operand_count(); ++i) {
if (ShapeUtil::IsScalar(user->operand(i)->shape())) {
++scalar_count;
} else {
reshape_or_broadcast_operand_index = i;
int64 scalar_broadcast_count = 0;
int64 broadcast_use_count = 0;
for (HloInstruction* user_operand : user->operands()) {
if (user_operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
++scalar_broadcast_count;
} else if (broadcast == user_operand) {
++broadcast_use_count;
}
}
if (scalar_count != user->operand_count() - 1) {
if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
continue;
}
VLOG(4) << "Sinking reshape or broadcast after user:";
VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString();
VLOG(4) << " old user: " << user->ToString();
CHECK_EQ(user->operand(reshape_or_broadcast_operand_index),
reshape_or_broadcast);
auto new_user_operands = user->operands();
new_user_operands[reshape_or_broadcast_operand_index] = operand;
auto new_user = computation_->AddInstruction(user->CloneWithNewOperands(
ShapeUtil::MakeShapeWithLayout(
user->shape().element_type(),
AsInt64Slice(operand->shape().dimensions()),
LayoutUtil::MinorToMajor(operand->shape())),
new_user_operands));
VLOG(4) << " new user: " << new_user->ToString();
HloInstruction* new_reshape_or_broadcast = nullptr;
if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) {
new_reshape_or_broadcast =
computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShapeWithLayout(
user->shape().element_type(),
AsInt64Slice(reshape_or_broadcast->shape().dimensions()),
LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())),
new_user));
} else {
TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast);
new_reshape_or_broadcast =
computation_->AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShapeWithLayout(
user->shape().element_type(),
AsInt64Slice(reshape_or_broadcast->shape().dimensions()),
LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())),
new_user, reshape_or_broadcast->dimensions()));
std::vector<HloInstruction*> new_operands;
new_operands.reserve(user->operand_count());
for (HloInstruction* user_operand : user->operands()) {
if (user_operand->opcode() == HloOpcode::kBroadcast &&
ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
new_operands.push_back(
computation_->AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(
operand->shape(), user_operand->shape().element_type()),
user_operand->mutable_operand(0), {})));
} else {
CHECK_EQ(broadcast, user_operand);
new_operands.push_back(operand);
}
}
VLOG(4) << " new reshape/broadcast: "
<< new_reshape_or_broadcast->ToString();
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast));
VLOG(4) << "Sinking broadcast after user:";
VLOG(4) << " old broadcast: " << broadcast->ToString();
VLOG(4) << " old user: " << user->ToString();
HloInstruction* new_user =
computation_->AddInstruction(user->CloneWithNewOperands(
ShapeUtil::ChangeElementType(operand->shape(),
user->shape().element_type()),
new_operands));
VLOG(4) << " new user: " << new_user->ToString();
HloInstruction* new_broadcast =
computation_->AddInstruction(HloInstruction::CreateBroadcast(
user->shape(), new_user, broadcast->dimensions()));
VLOG(4) << " new broadcast: " << new_broadcast->ToString();
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
changed = true;
}
return changed;
@ -1674,16 +1670,6 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
}
}
// A Reshape that feeds a unary element-wise operation can sink the
// reshape after the unary element-wise operation.
TF_ASSIGN_OR_RETURN(
bool sink_succeeded,
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape));
changed_ |= sink_succeeded;
if (sink_succeeded) {
return Status::OK();
}
// Make this a bitcast if possible.
if (is_layout_sensitive_ &&
ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
@ -1788,6 +1774,46 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
new_reduce_dimensions, function));
}
// If the reduction results in the same number of elements, then the only
// possible side effect would be a reshape. Since the init_value is an
// identity of the reduction function, we can therefore replace the reduce
// with a simple reshape, ignoring the reduction function completely.
if (ShapeUtil::ElementsIn(reduce->shape()) ==
ShapeUtil::ElementsIn(arg->shape())) {
return ReplaceWithNewInstruction(
reduce, HloInstruction::CreateReshape(reduce->shape(), arg));
}
// If a reduce feeds a reduce with the same computation and initial value,
// they can be combined into a single reduce.
if (arg->opcode() == HloOpcode::kReduce &&
init_value->Identical(*arg->operand(1)) &&
*function == *arg->to_apply()) {
// Create a new reduce with the combined reduction dimensions of both
// reduces.
std::vector<int64> arg_dims = arg->dimensions();
std::sort(arg_dims.begin(), arg_dims.end());
std::vector<int64> reduce_dims = reduce->dimensions();
std::sort(reduce_dims.begin(), reduce_dims.end());
// Transform reduce_dims to the same rank as the operand of the operand.
for (int64 arg_dim : arg_dims) {
for (int64& dim : reduce_dims) {
if (dim >= arg_dim) {
++dim;
}
}
}
std::vector<int64> new_dimensions;
new_dimensions.reserve(arg->dimensions().size() +
reduce->dimensions().size());
std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
reduce_dims.end(), std::back_inserter(new_dimensions));
return ReplaceWithNewInstruction(
reduce,
HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0),
init_value, new_dimensions, function));
}
// A reshape that collapses multiple dimensions into a dimension being
// reduced can just reduce all of those dimensions instead of doing a
// collapsing reshape before a reduction.
@ -1832,15 +1858,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
new_reduce_dimensions, function));
}
}
if (ShapeUtil::ElementsIn(reduce->shape()) ==
ShapeUtil::ElementsIn(arg->shape()) ||
ShapeUtil::HasZeroElements(arg->shape())) {
auto reshape = computation_->AddInstruction(
HloInstruction::CreateReshape(reduce->shape(), arg));
return ReplaceWithNewInstruction(
reduce, HloInstruction::CreateMap(reduce->shape(),
{init_value, reshape}, function));
}
return Status::OK();
}
@ -1860,7 +1877,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateMap(reduce_window->shape(),
{operand, reduce_window->mutable_operand(1)},
{reduce_window->mutable_operand(1), operand},
function));
}

View File

@ -74,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
EXPECT_EQ(root, param0);
}
// Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
HloComputation::Builder builder(TestName());
// Create add computation.
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
HloComputation* add_computation = nullptr;
{
HloComputation::Builder builder(TestName() + ".add");
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
HloInstruction* p0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape, "p0"));
HloInstruction* p1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
add_computation = module().AddEmbeddedComputation(builder.Build());
}
Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r4f32, "param"));
std::vector<int64> dims0({0});
Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7});
HloInstruction* reduce0 = builder.AddInstruction(
HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation));
std::vector<int64> dims1({1, 2});
Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero,
dims1, add_computation));
module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
HloInstruction* root = module().entry_computation()->root_instruction();
EXPECT_THAT(root, op::Reduce(param, zero));
EXPECT_EQ(root->dimensions(), std::vector<int64>({0, 2, 3}));
}
// Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@ -1351,32 +1389,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape));
}
TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
HloComputation::Builder builder(TestName());
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param"));
HloInstruction* movable_reshape =
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param));
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
HloOpcode::kMaximum, movable_reshape, zero));
auto computation = module().AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Reshape(param), zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
bitcasting_callback());
simplifier.Run(&module()).ValueOrDie();
EXPECT_THAT(computation->root_instruction(),
op::Reshape(op::Maximum(param, zero)));
}
// Regression test for a bug in the reshape sinking transformation, where
// moving a reshape to a scalar led to a crash.
TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
@ -1740,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@ -1785,7 +1797,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
@ -1807,7 +1819,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@ -1830,7 +1842,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@ -1958,7 +1970,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
auto module = CreateNewModule();
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewModule();
auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
@ -2086,7 +2099,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@ -2116,7 +2129,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@ -2147,7 +2160,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@ -2177,7 +2190,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
@ -2210,7 +2223,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@ -2226,10 +2239,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
HloInstruction* broadcast =
builder.AddInstruction(HloInstruction::CreateBroadcast(
broadcast_shape, scalar_param,
AsInt64Slice(broadcast_shape.dimensions())));
HloInstruction* broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
@ -2245,10 +2256,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
// Running simplification again should not result in any further changes.
ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
@ -2263,10 +2274,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
HloInstruction* broadcast =
builder.AddInstruction(HloInstruction::CreateBroadcast(
broadcast_shape, forty_two,
AsInt64Slice(broadcast_shape.dimensions())));
HloInstruction* broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
HloInstruction* transpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
@ -2285,7 +2294,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
@ -2294,7 +2303,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
auto module = CreateNewModule();
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@ -2375,7 +2385,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
auto module = CreateNewModule();
// TODO(b/80488902): verify this module.
auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@ -2470,7 +2481,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);

View File

@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
// Runs the visitor on a computation.
static bool Run(HloComputation* computation, bool rewrite_training_op,
bool rewrite_inference_op, bool rewrite_grad_op,
bool use_fusion);
bool rewrite_inference_op, bool rewrite_grad_op);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
@ -70,21 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
explicit BatchNormExpanderVisitor(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion)
bool rewrite_grad_op)
: computation_(computation),
rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
rewrite_grad_op_(rewrite_grad_op) {}
HloComputation* GetOrCreateScalarAddComputation(
PrimitiveType primitive_type) {
HloComputation** scalar_add_computation =
&scalar_add_computations_[primitive_type];
if (*scalar_add_computation) {
return *scalar_add_computation;
}
HloComputation::Builder b("scalar_add_computation");
Shape shape = ShapeUtil::MakeShape(primitive_type, {});
auto scalar_lhs = b.AddInstruction(
@ -93,26 +85,39 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
*scalar_add_computation =
computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
return *scalar_add_computation;
return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
}
// Current HloComputation instance the BatchNormExpander is
// traversing.
HloComputation* computation_;
std::unique_ptr<HloInstruction> Rsqrt(
HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
add_instruction) {
HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast(
operand->shape(),
add_instruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
add_instruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(-0.5f))))),
{}));
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
operand, exponent);
}
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
// Whether rewrite has occurred.
bool changed_ = false;
// Cached computations for adding two scalars.
tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
scalar_add_computations_;
std::unique_ptr<HloInstruction> Mean(
int64 element_count, HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
add_instruction) {
HloInstruction* elem_count_recip =
add_instruction(HloInstruction::CreateBroadcast(
operand->shape(),
add_instruction(HloInstruction::CreateConvert(
ShapeUtil::MakeShape(operand->shape().element_type(), {}),
add_instruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(1.0 / element_count))))),
{}));
return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply,
operand, elem_count_recip);
}
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
@ -136,6 +141,16 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
changed_ = true;
return Status::OK();
}
// Current HloComputation instance the BatchNormExpander is
// traversing.
HloComputation* computation_;
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
// Whether rewrite has occurred.
bool changed_ = false;
};
} // namespace
@ -143,13 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool BatchNormExpanderVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
bool rewrite_grad_op, bool use_fusion) {
bool rewrite_grad_op) {
BatchNormExpanderVisitor visitor(
computation,
/*rewrite_training_op=*/rewrite_training_op,
/*rewrite_inference_op=*/rewrite_inference_op,
/*rewrite_grad_op=*/rewrite_grad_op,
/*use_fusion=*/use_fusion);
/*rewrite_grad_op=*/rewrite_grad_op);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@ -167,6 +181,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
added_instructions.push_back(added_inst);
return added_inst;
};
auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
HloInstruction* a, HloInstruction* b) {
return add(HloInstruction::CreateBinary(shape, opcode, a, b));
};
int64 instruction_count_before = computation_->instruction_count();
// Expand batch norm training into smaller HLO ops.
@ -176,12 +194,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
int64 feature_index = batch_norm->feature_index();
const int64 feature_count = operand_shape.dimensions(feature_index);
const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
auto elements_per_feature_literal =
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
int64 elements_per_feature_int64 = size_in_elements / feature_count;
HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
@ -193,8 +206,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
@ -213,8 +227,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
GetOrCreateScalarAddComputation(ptype);
// X^2.
auto operand_squared = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kMultiply, operand, operand));
auto operand_squared =
add_binary(operand_shape, HloOpcode::kMultiply, operand, operand);
// Sum[X].
auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
dimensions_without_feature,
@ -225,71 +239,48 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
feature_shape, operand_squared, zero, dimensions_without_feature,
add_reduce_computation));
// Fuse two parallel reduces together to improve performance.
if (use_fusion_ && !batch_norm->has_sharding()) {
auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum}));
auto fused = computation_->CreateFusionInstruction(
{tuple, sum, squared_sum, operand_squared},
HloInstruction::FusionKind::kInput);
sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
squared_sum =
add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
}
// E[X].
auto mean = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kDivide, sum, elements_per_feature));
auto mean = add(Mean(elements_per_feature_int64, sum, add));
auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
// E[X^2].
auto square_mean = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature));
auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add));
// E^2[X].
auto mean_square = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kMultiply, mean, mean));
auto mean_square =
add_binary(feature_shape, HloOpcode::kMultiply, mean, mean);
// Var[X].
auto var = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kSubtract, square_mean, mean_square));
auto var =
add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square);
auto var_broadcasted =
add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
auto var_add_epsilon = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half =
add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
auto var_add_epsilon =
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
// X - E[X].
auto operand_minus_mean = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
operand, mean_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon].
auto normalized = add(
HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon));
auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
auto scaled_normalized = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
normalized, scale_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
auto shifted_normalized = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted));
auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd,
scaled_normalized, offset_broadcasted);
auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});
@ -331,8 +322,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(epsilon_literal))),
{}));
std::vector<int64> dimensions_without_feature;
@ -349,6 +343,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
added_instructions.push_back(added_inst);
return added_inst;
};
auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
HloInstruction* a, HloInstruction* b) {
return add(HloInstruction::CreateBinary(shape, opcode, a, b));
};
int64 instruction_count_before = computation_->instruction_count();
auto scale_broadcasted = add(
@ -364,30 +362,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
auto var_add_epsilon = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half =
add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
auto var_add_epsilon =
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
// X - E[X].
auto operand_minus_mean = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
operand, mean_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon].
auto normalized = add(
HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon));
auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
auto scaled_normalized = add(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
normalized, scale_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
auto shifted_normalized = HloInstruction::CreateBinary(
@ -435,6 +426,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
added_instructions.push_back(added_inst);
return added_inst;
};
auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
HloInstruction* a, HloInstruction* b) {
return add(HloInstruction::CreateBinary(shape, opcode, a, b));
};
int64 instruction_count_before = computation_->instruction_count();
HloInstruction* activation = batch_norm->mutable_operand(0);
@ -450,26 +445,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
const int64 feature_count = activation_shape.dimensions(feature_index);
auto elements_per_feature_literal =
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half =
add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon =
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
auto epsilon_activation = add(
HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {}));
auto epsilon_feature =
add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {}));
std::vector<int64> dimensions_without_feature;
@ -489,26 +478,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));
// rsqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary(
activation_shape, HloOpcode::kPower,
add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
variance_broadcasted, epsilon)),
neg_half));
auto rsqrt_var_add_epsilon_broadcasted =
add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
variance_broadcasted, epsilon_activation),
add));
auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kPower,
add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance,
epsilon)),
neg_half));
auto rsqrt_var_add_epsilon = add(Rsqrt(
add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature),
add));
// X - E[X].
auto activation_minus_mean = add(HloInstruction::CreateBinary(
activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted));
auto activation_minus_mean = add_binary(
activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted);
// Grad[Y] * (X - E[X]).
auto grad_output_times_activiation_minus_mean =
add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
grad_output, activation_minus_mean));
add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
activation_minus_mean);
HloComputation* add_reduce_computation =
GetOrCreateScalarAddComputation(ptype);
@ -524,25 +510,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
feature_shape, grad_output, zero, dimensions_without_feature,
add_reduce_computation));
if (use_fusion_ && !batch_norm->has_sharding()) {
auto tuple = add(HloInstruction::CreateTuple(
{sum_grad_output_times_activiation_minus_mean, grad_beta}));
auto fused = computation_->CreateFusionInstruction(
{tuple, sum_grad_output_times_activiation_minus_mean, grad_beta},
HloInstruction::FusionKind::kInput);
sum_grad_output_times_activiation_minus_mean =
add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
grad_beta =
add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
}
// Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
auto grad_scale = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kMultiply,
sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon));
auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
sum_grad_output_times_activiation_minus_mean,
rsqrt_var_add_epsilon);
// I2 = Sum(Grad[Y])
auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
@ -554,39 +525,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
{feature_index}));
// I4 = (X - E[X]) * I3
auto i4 = add(HloInstruction::CreateBinary(
activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean));
auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3,
activation_minus_mean);
// I5 = I4 / (Var[X] + epsilon)
auto i5 = add(HloInstruction::CreateBinary(
activation_shape, HloOpcode::kDivide, i4,
add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
variance_broadcasted, epsilon))));
auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4,
add_binary(activation_shape, HloOpcode::kAdd,
variance_broadcasted, epsilon_activation));
// scale * rsqrt[Var[X] + epsilon] * 1/N
auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
activation_shape, HloOpcode::kMultiply, scale_broadcasted,
rsqrt_var_add_epsilon_broadcasted));
auto scale_times_rsqrt_var_add_epsilon =
add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
rsqrt_var_add_epsilon_broadcasted);
scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon,
elements_per_feature));
scale_times_rsqrt_var_add_epsilon = add(
Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add));
auto i1 =
add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
grad_output, elements_per_feature));
auto elements_per_feature_literal =
Literal::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
add(HloInstruction::CreateBroadcast(
activation_shape, elements_per_feature, {})));
// I6 = I1 - I2 - I5
auto i6 = add(HloInstruction::CreateBinary(
auto i6 = add_binary(
activation_shape, HloOpcode::kSubtract,
add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract,
i1, i2)),
i5));
add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5);
// Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
auto grad_activation =
add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
scale_times_rsqrt_var_add_epsilon, i6));
auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply,
scale_times_rsqrt_var_add_epsilon, i6);
auto tuple =
HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
if (batch_norm->has_sharding()) {
@ -615,8 +587,8 @@ StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
rewrite_inference_op_, rewrite_grad_op_,
use_fusion_)) {
rewrite_inference_op_,
rewrite_grad_op_)) {
changed = true;
}
}

View File

@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface {
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
bool rewrite_inference_op = false,
bool rewrite_grad_op = false, bool use_fusion = true)
bool rewrite_grad_op = false)
: rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
rewrite_grad_op_(rewrite_grad_op),
use_fusion_(use_fusion) {}
rewrite_grad_op_(rewrite_grad_op) {}
~BatchNormExpander() = default;
tensorflow::StringPiece name() const override { return "batchnorm_expander"; }
@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
bool use_fusion_;
};
} // namespace xla

View File

@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto builder = HloComputation::Builder(TestName());
auto module = CreateNewModule();
HloComputation::Builder sum_builder("add");
auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
sum_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}));
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
sum));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* tuple = builder.AddInstruction(
HloInstruction::CreateTuple({gte_a, convert_gte_b}));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(FoldConversions(module.get()));

View File

@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
}
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
auto module = CreateNewModule();
HloComputation::Builder sum_builder("sum");
auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
sum_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
HloComputation* reduction =
module->AddEmbeddedComputation(sum_builder.Build());
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b},
reduction));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));

View File

@ -135,6 +135,7 @@ Status GatherComputationsByAllocationType(
worklist.push_back(std::make_pair(subcomputation,
false)); // Not thread local.
break;
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@ -33,12 +32,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
@ -82,7 +81,7 @@ const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
class BufferAssignmentTest : public HloTestBase {
protected:
BufferAssignmentTest() : computation_tracker_() {}
BufferAssignmentTest() {}
~BufferAssignmentTest() override {}
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
@ -252,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase {
return total_size;
}
// Computation tracker for nested computations.
ComputationTracker computation_tracker_;
// Shapes for use in the examples.
Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
@ -375,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) {
// param1[100] --------------/--------/
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, ""));
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, ""));
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@ -422,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
// share anything.
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, ""));
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, ""));
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@ -481,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
// have the color 0, which allows the mul and add to share buffers.
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, ""));
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, ""));
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@ -551,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
//
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, ""));
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, ""));
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@ -605,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) {
// Creates the main kernel and verifies instruction counts.
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32a100x10_, ""));
HloInstruction::CreateParameter(0, f32a100x10_, "p"));
auto map = builder.AddInstruction(
HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
module->AddEntryComputation(builder.Build());
@ -658,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32a100x10_, ""));
HloInstruction::CreateParameter(0, f32a100x10_, "p"));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
auto exp2 = builder.AddInstruction(
@ -822,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
// param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32vec100_, ""));
HloInstruction::CreateParameter(0, f32vec100_, "p"));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
auto tanh = builder.AddInstruction(
@ -1500,11 +1496,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
// param1[100] --------------/--------/
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32vec100_, ""));
HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32vec100_, ""));
HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@ -1540,7 +1536,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
// be {%rev, %neg, %concat}. This occurs right at the concat itself.
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32vec100_, ""));
HloInstruction::CreateParameter(0, f32vec100_, "p"));
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
auto rev = builder.AddInstruction(
@ -1797,7 +1793,7 @@ ENTRY %test_module {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
tools::Parse(module_str));
ParseHloString(module_str));
// Run CopyInsertion and check if the graph constructed above doesn't need
// any copies inserted for BufferAssignment to run.

View File

@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
case HloOpcode::kConditional:
case HloOpcode::kWhile:
return CallContext::kSequential;
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:

View File

@ -19,9 +19,6 @@ limitations under the License.
#include <map>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

View File

@ -1,78 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include <utility>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
std::shared_ptr<Executable> CompilationCache::Insert(
std::unique_ptr<Executable> executable,
const HloModuleConfig& module_config) {
tensorflow::mutex_lock lock(mutex_);
CacheKey key =
BuildKey(executable->entry_computation_handle(), module_config);
VLOG(2) << "inserting cache key: " << key;
if (cache_.count(key) == 0) {
cache_.emplace(key, std::move(executable));
} else {
// Executable already exists in the cache. This can happen if two Execute
// calls for a new computation are received simultaneously by the
// service. In this case, we discard the Executable given as a parameter and
// return what is in the cache. This is necessary because the service relies
// on the cache to keep ownership of the Executable. We only want to store
// one Executable for a given computation version and we can't discard the
// executable which is in the cache because it may be in use.
executable.reset();
}
return cache_.at(key);
}
std::shared_ptr<Executable> CompilationCache::LookUp(
const VersionedComputationHandle& versioned_handle,
const HloModuleConfig& module_config) const {
tensorflow::mutex_lock lock(mutex_);
CacheKey key = BuildKey(versioned_handle, module_config);
VLOG(2) << "looking up cache key: " << key;
if (cache_.count(key) == 0) {
VLOG(2) << "cache key not found: " << key;
return nullptr;
} else {
std::shared_ptr<Executable> result = cache_.at(key);
VLOG(2) << "hit executable with module config: "
<< result->module_config().compilation_cache_key();
return result;
}
}
CompilationCache::CacheKey CompilationCache::BuildKey(
const VersionedComputationHandle& versioned_handle,
const HloModuleConfig& module_config) const {
// The computation shape is represented entirely by its ProgramShape member,
// so just serialize the proto as part of the key.
return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::",
versioned_handle.version, "::",
module_config.compilation_cache_key());
}
} // namespace xla

View File

@ -1,78 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
#include <map>
#include <memory>
#include <string>
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace xla {
// A cache which stores Executables indexed by computation handle and version.
class CompilationCache {
public:
CompilationCache() {}
// Insert the given Executable into the cache. Return a bare Executable
// pointer for the caller to use. Note: the returned pointer will *not* be the
// same as the given unique pointer if the computation already exists in the
// cache. See comments in the .cc implementation for details of this case.
//
// module_config is provided by the caller, instead of being taken from the
// executable, so that we can insert keys into the compilation cache that are
// devoid of layout (where XLA gets to choose what layout to compile).
//
// A shared_ptr is returned so the caller can keep the Executable from being
// destructed in the event that the Executable is evicted from the
// computation cache (and the cache's shared_ptr to the Executable is
// destructed).
std::shared_ptr<Executable> Insert(std::unique_ptr<Executable> executable,
const HloModuleConfig& module_config);
// Lookup the Executable for the specified versioned computation in the cache.
// Return a shared_ptr to the Executable if it exists in the cache. Return
// nullptr otherwise.
std::shared_ptr<Executable> LookUp(
const VersionedComputationHandle& versioned_handle,
const HloModuleConfig& module_config) const;
protected:
mutable tensorflow::mutex mutex_;
// Map from versioned handle with program layout to Executable built
// for that computation version and program layout.
using CacheKey = string;
CacheKey BuildKey(const VersionedComputationHandle& versioned_handle,
const HloModuleConfig& module_config) const;
std::map<CacheKey, std::shared_ptr<Executable>> cache_ GUARDED_BY(mutex_);
private:
TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@ -104,56 +103,4 @@ CompileOnlyService::CompileAheadOfTime(
return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
}
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
const AotCompilationOptions& options) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
for (const AotComputationInstance& instance : computations) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(instance.computation));
VersionedComputationHandle versioned_handle =
user_computation->GetVersionedHandle();
const DebugOptions& debug_options = options.debug_options();
// Dump computation proto state if flag is set.
const string& directory_path = debug_options.xla_dump_computations_to();
if (!directory_path.empty()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<SessionModule> session_module,
computation_tracker_.SnapshotComputation(versioned_handle.handle));
string filename = tensorflow::strings::StrCat(
"computation_", versioned_handle.handle.handle(), "__",
session_module->entry().name(), "__version_",
versioned_handle.version);
const string& per_host_path = tensorflow::io::JoinPath(
directory_path, tensorflow::port::Hostname());
TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename,
*session_module));
}
TF_ASSIGN_OR_RETURN(
std::shared_ptr<const ProgramShape> program_shape,
user_computation->ComputeProgramShape(versioned_handle.version));
ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, instance.argument_layouts,
&execution_options, user_computation));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
versioned_handle, *module_config,
/*include_unreachable_instructions=*/true));
TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
hlo_modules.push_back(std::move(hlo_module));
}
return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
}
} // namespace xla

View File

@ -38,24 +38,7 @@ class CompileOnlyService : public Service {
static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
const ServiceOptions& options);
// A description of a computation to compile using CompileAheadOfTime.
struct AotComputationInstance {
ComputationHandle computation;
std::vector<const Shape*> argument_layouts;
const Shape* result_layout = nullptr;
};
// Compiles a list of computations for ahead-of-time execution. This is
// intended for use in static compilation. See
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
const AotCompilationOptions& Options);
// A description of a xla computation to compile using CompileAheadOfTime.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
struct AotXlaComputationInstance {
HloModuleProto computation;
std::vector<const Shape*> argument_layouts;
@ -65,31 +48,15 @@ class CompileOnlyService : public Service {
// Compiles a list of xla computations for ahead-of-time execution. This is
// intended for use in static compilation. See
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
//
// TODO(b/74197823): This is a part of a NOT YET ready refactor.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
const AotCompilationOptions& options);
// Override Service methods that require or imply the existence of an
// execute backend. Note that this does not include TransferToClient, as
// computing constants produces global data that we may wish to transfer.
Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
Status ExecuteParallel(const ExecuteParallelRequest* arg,
ExecuteParallelResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
}
Status ExecuteAsync(const ExecuteAsyncRequest* arg,
ExecuteAsyncResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
}
Status WaitForExecution(const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");

View File

@ -28,8 +28,9 @@ namespace xla {
/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
tensorflow::LINKER_INITIALIZED);
std::vector<string> Compiler::ComputeBackendConfigs(
const HloInstruction& hlo, se::StreamExecutor* executor) const {
std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
se::StreamExecutor* executor) const {
CHECK(executor != nullptr);
return {};
}

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@ -161,8 +162,9 @@ class Compiler {
//
// The stream executor is passed in to provide information about the hardware
// that the backend configurations would be targeting.
virtual std::vector<string> ComputeBackendConfigs(
const HloInstruction& hlo, se::StreamExecutor* executor) const;
virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
ComputeBackendConfigs(const HloInstruction& hlo,
se::StreamExecutor* executor) const;
// Compiles the HLO module for ahead-of-time execution. This is intended for
// use in static compilation.

View File

@ -32,12 +32,21 @@ namespace xla {
// mutable layouts.
class ComputationLayout {
public:
// Creates a new ComputationLayout with the given result layout.
explicit ComputationLayout(ShapeLayout result_layout)
: result_layout_(std::move(result_layout)) {}
// Constructs a ComputationLayout from a ProgramShape. The layouts of the
// parameters and results are set to the default layout. Layouts in the
// ProgramShape are ignored if ignore_layouts is true.
explicit ComputationLayout(const ProgramShape& program_shape,
bool ignore_layouts = true);
// Adds a new parameter layout to the computation layout.
void add_parameter_layout(ShapeLayout shape_layout) {
parameter_layouts_.push_back(std::move(shape_layout));
}
// Returns the layout of a particular parameter.
const ShapeLayout& parameter_layout(int64 param_no) const {
return parameter_layouts_[param_no];

View File

@ -1,256 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include <list>
#include <string>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
using ::tensorflow::strings::Appendf;
namespace xla {
ComputationTracker::ComputationTracker() : next_computation_(1) {}
ComputationHandle ComputationTracker::NewComputation(
const string& computation_name) {
tensorflow::mutex_lock lock(computation_mutex_);
ComputationHandle computation_handle;
int64 handle_value = next_computation_++;
computation_handle.set_handle(handle_value);
opaque_to_computation_[handle_value] =
MakeUnique<UserComputation>(computation_name, computation_handle);
return computation_handle;
}
StatusOr<ComputationHandle> ComputationTracker::LoadSessionModule(
const SessionModule& session_module) {
tensorflow::mutex_lock lock(computation_mutex_);
// For each embedded computation, create a new computation based on its
// serialized data, and place the mapping from the old computation handle to
// the new computation handle.
// Build a mapping from old embedded computation handles to new computation
// handles. We build the ID mapping first since the embedded computations are
// in no particular order and may refer to each other.
std::map<int64, ComputationHandle> old_to_new;
for (const SessionComputation& computation :
session_module.embedded_computations()) {
const int64 old_handle = computation.computation_handle().handle();
if (!old_to_new.emplace(old_handle, AllocateHandle()).second) {
return InvalidArgument("Duplicate embedded computation handle %lld",
old_handle);
}
}
// Create a new computation from each serialized embedded computation.
for (const SessionComputation& computation :
session_module.embedded_computations()) {
const int64 old_handle = computation.computation_handle().handle();
const ComputationHandle& new_handle = old_to_new[old_handle];
TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
UserComputation::MakeWithRemapping(
computation, new_handle, old_to_new));
}
// Finally, place the entry computation in the tracker with all of the
// remappings populated from the above.
const int64 old_handle = session_module.entry().computation_handle().handle();
TF_ASSIGN_OR_RETURN(
old_to_new[old_handle],
LoadSessionComputation(session_module.entry(), &old_to_new));
return old_to_new[old_handle];
}
StatusOr<std::unique_ptr<SessionModule>>
ComputationTracker::SnapshotComputation(const ComputationHandle& computation) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation));
const VersionedComputationHandle entry_versioned_handle =
user_computation->GetVersionedHandle();
std::set<VersionedComputationHandle> visited;
std::list<VersionedComputationHandle> post_order;
{
tensorflow::mutex_lock lock(computation_mutex_);
ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order);
}
auto session_module = MakeUnique<SessionModule>();
*session_module->mutable_entry() =
Resolve(entry_versioned_handle.handle)
.ValueOrDie()
->CloneSessionComputation(entry_versioned_handle.version);
for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) {
*session_module->add_embedded_computations() =
Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version);
}
return std::move(session_module);
}
StatusOr<UserComputation*> ComputationTracker::Resolve(
const ComputationHandle& computation) const {
tensorflow::mutex_lock lock(computation_mutex_);
return ResolveInternal(computation);
}
ComputationHandle ComputationTracker::AllocateHandle() {
int64 handle_value = next_computation_++;
ComputationHandle result;
result.set_handle(handle_value);
return result;
}
StatusOr<ComputationHandle> ComputationTracker::LoadSessionComputation(
const SessionComputation& session_computation,
std::map<int64, ComputationHandle>* old_to_new) {
TF_RET_CHECK(old_to_new != nullptr);
const ComputationHandle new_handle = AllocateHandle();
(*old_to_new)[session_computation.computation_handle().handle()] = new_handle;
TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
UserComputation::MakeWithRemapping(
session_computation, new_handle, *old_to_new));
return new_handle;
}
StatusOr<UserComputation*> ComputationTracker::ResolveInternal(
const ComputationHandle& computation) const {
auto it = opaque_to_computation_.find(computation.handle());
if (it == opaque_to_computation_.end()) {
return NotFound("computation handle not found: %lld", computation.handle());
}
UserComputation* user_computation = it->second.get();
return user_computation;
}
void ComputationTracker::ComputeComputationPostOrder(
const VersionedComputationHandle& versioned_handle,
std::set<VersionedComputationHandle>* visited,
std::list<VersionedComputationHandle>* post_order) const {
if (visited->count(versioned_handle) > 0) {
CHECK_EQ(1, visited->count(versioned_handle));
return;
}
UserComputation* computation =
ResolveInternal(versioned_handle.handle).ValueOrDie();
std::vector<VersionedComputationHandle> embedded_handles =
computation->GetEmbeddedComputations(versioned_handle.version);
for (const auto& embedded_handle : embedded_handles) {
ComputeComputationPostOrder(embedded_handle, visited, post_order);
}
visited->insert(versioned_handle);
post_order->push_back(versioned_handle);
}
StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
const VersionedComputationHandle& entry_handle,
const HloModuleConfig& config,
bool include_unreachable_instructions) const {
tensorflow::mutex_lock lock(computation_mutex_);
VLOG(1) << "BuildHloModule(" << entry_handle
<< ", include_unreachable_instructions="
<< include_unreachable_instructions << ")";
XLA_VLOG_LINES(1, ToStringInternal());
TF_ASSIGN_OR_RETURN(UserComputation * entry_computation,
ResolveInternal(entry_handle.handle));
// Build a topological sort of the entry and any embedded computations as a
// list. The root of the computation will be the last element in the list.
std::set<VersionedComputationHandle> visited;
std::list<VersionedComputationHandle> post_order;
ComputeComputationPostOrder(entry_handle, &visited, &post_order);
// Map from ComputationHandle value and computation version to HloComputation.
std::map<VersionedComputationHandle, HloComputation*> hlo_computations;
// The resolver lambda resolves VersionedHandles to embedded
// HloComputation*. This is required by UserComputation::BuildHloComputation
// when lowering calling operations (map, reduce etc).
auto resolver = [&hlo_computations](
const VersionedComputationHandle& versioned_handle) -> HloComputation* {
CHECK_GT(hlo_computations.count(versioned_handle), 0);
return hlo_computations.at(versioned_handle);
};
// Print the post-order list for this entry computation.
if (VLOG_IS_ON(2)) {
VLOG(2) << "Visiting UserComputations in post order:";
for (const VersionedComputationHandle& versioned_handle : post_order) {
VLOG(2) << " " << versioned_handle;
}
}
string module_name =
tensorflow::strings::StrCat(entry_computation->name(), "_module");
auto module = MakeUnique<HloModule>(module_name, entry_handle, config);
for (auto versioned_handle : post_order) {
UserComputation* computation =
ResolveInternal(versioned_handle.handle).ValueOrDie();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> hlo_computation,
computation->BuildHloComputation(versioned_handle.version, resolver,
config.debug_options(),
include_unreachable_instructions));
// Add the newly created computation to VersionedHandle-to-HloComputation
// map.
DCHECK_EQ(0, hlo_computations.count(versioned_handle));
hlo_computations[versioned_handle] = hlo_computation.get();
if (computation == entry_computation) {
module->AddEntryComputation(std::move(hlo_computation));
} else {
module->AddEmbeddedComputation(std::move(hlo_computation));
}
}
return std::move(module);
}
string ComputationTracker::ToString() const {
tensorflow::mutex_lock lock(computation_mutex_);
return ToStringInternal();
}
string ComputationTracker::ToStringInternal() const {
string out;
Appendf(&out, "ComputationTracker(%p):\n", this);
for (const auto& handle_computation : opaque_to_computation_) {
int64 handle = handle_computation.first;
const std::unique_ptr<UserComputation>& computation =
handle_computation.second;
Appendf(&out, " %4lld : %s \"%s\"\n", handle,
computation->GetVersionedHandle().ToString().c_str(),
computation->name().c_str());
}
return out;
}
} // namespace xla

View File

@ -1,147 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// Tracks computations for the XLA service; computations can be registered
// with a UserComputation instance and can be resolved from a handle for later
// use.
//
// This class is also capable of serializing/deserializing computations that it
// tracks (and to serialize properly you need to serialize all referred-to
// computations as well).
class ComputationTracker {
public:
ComputationTracker();
// Creates a new UserComputation object and returns the corresponding
// ComputationHandle for it.
//
// Precondition: user_computation is not already present in the map.
ComputationHandle NewComputation(const string& computation_name);
// Restores session data for a computation that has been serialized, and
// allocates a new computation handle for it.
StatusOr<ComputationHandle> LoadSessionModule(
const SessionModule& session_module);
// Snapshots a computation (referenced by the provided handle) at its latest
// version, returning a module where it is the entry, and any referred-to
// computations are entrained as "embedded" (non-entry) computations.
StatusOr<std::unique_ptr<SessionModule>> SnapshotComputation(
const ComputationHandle& computation);
// Resolves a ComputationHandle to a UserComputation that is present in the
// map.
StatusOr<UserComputation*> Resolve(
const ComputationHandle& computation) const;
// Builds an HLO module using the specified computation as the entry. The
// module will include the entry computation as well as all computations which
// are called directly or indirectly from the entry computation via operations
// like "map". config is the HLO module configuration to use for the
// constructed module.
// If include_unreachable_instructions is true, then instructions
// which are not reachable from the root are lowered into HloInstructions
// including unreachable parameters. This ensures the entry HloComputation has
// the same program shape (ProgramShape) as the entry UserComputation.
StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
const VersionedComputationHandle& entry_handle,
const HloModuleConfig& config,
bool include_unreachable_instructions = true) const;
string ToString() const;
private:
// Bumps the next_computation_ number and returns the allocated number wrapped
// in a ComputationHandle.
ComputationHandle AllocateHandle()
EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
// Loads a session computation into a UserComputation, registers it, and
// returns the computation handle of the registered computation. If old_to_new
// is provided, it is used for remapping references to computations present in
// session_computation.
//
// old_to_new will be updated with the mapping from session_computation's old
// handle to the returned handle value, and may not be null.
StatusOr<ComputationHandle> LoadSessionComputation(
const SessionComputation& session_computation,
std::map<int64, ComputationHandle>* old_to_new)
EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
// Internal implementation of Resolve method which requires, but does not
// acquire the mutex.
StatusOr<UserComputation*> ResolveInternal(
const ComputationHandle& computation) const
EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
// Builds a post order sort of a computation ("entry") and all of its embedded
// computations including all transitively embedded computations. An embedded
// computation (the callee) will always appear in the sort before the
// computation which calls the embedded computation (the caller). Necessarily,
// the entry computation is the last element in the sort. visited and
// post_order should be empty when calling. post_order contains the post order
// sort when the function return.
void ComputeComputationPostOrder(
const VersionedComputationHandle& versioned_handle,
std::set<VersionedComputationHandle>* visited,
std::list<VersionedComputationHandle>* post_order) const
EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
// Guards the computation mapping. Marked mutable so that the Resolve method
// can remain const; Resolve does't really modify the tracker in any way, but
// it has to lock the mutex for safety.
mutable tensorflow::mutex computation_mutex_;
// The next sequence number to assign to a computation, guarded by the same
// mutex as the mapping as they'll be mutated at the same time.
int64 next_computation_ GUARDED_BY(computation_mutex_);
// Mapping from ComputationHandle value to the corresponding registered
// UserComputation object.
std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_
GUARDED_BY(computation_mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_

View File

@ -1636,8 +1636,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) {
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
config);
HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_SequentialWhiles");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@ -1677,8 +1676,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) {
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
config);
HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_ParallelWhiles");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@ -1750,8 +1748,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
for (int i = 0; i < num_iters; ++i) {
auto builder = HloComputation::Builder("BM_ParallelWhiles");
HloModule module("BM_ManyElementTuple", VersionedComputationHandle(),
config);
HloModule module("BM_ManyElementTuple", config);
for (int j = 0; j < num_tuple_inputs; ++j) {
tuple_params[j] = builder.AddInstruction(
HloInstruction::CreateParameter(j, element_shape, ""));

View File

@ -649,10 +649,10 @@ tf_cc_test(
deps = [
":cpu_instruction_fusion",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@ -706,9 +706,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@ -898,6 +898,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"@llvm//:core",
"@llvm//:support",
],
@ -958,7 +959,7 @@ tf_cc_test(
":ir_emission_utils",
":target_machine_features_fake",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)

View File

@ -264,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true,
/*use_fusion=*/false);
/*rewrite_grad_op=*/true);
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; },

Some files were not shown because too many files have changed in this diff Show More