Ambiq squashed commits

This commit is contained in:
Steve Nesae 2018-11-06 12:47:02 -06:00
parent d391ba441b
commit 9caf68cf7b
3376 changed files with 170761 additions and 65412 deletions

View File

@@ -0,0 +1,24 @@
---
name: TensorFlow Lite Op Request
about: Use this template for reporting ops you are using or missing.
---
**System information**
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
- TensorFlow installed from (source or binary):
- TensorFlow version (or github SHA if from source):
**Provide the text output from tflite_convert**
```
# Copy and paste here
```
Also, please include a link to a GraphDef or the model if possible.
**Any other info / logs**
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.

View File

@@ -14,7 +14,8 @@ data flow graphs. The graph nodes represent mathematical operations, while
the graph edges represent the multidimensional data arrays (tensors) that flow
between them. This flexible architecture enables you to deploy computation to one
or more CPUs or GPUs in a desktop, server, or mobile device without rewriting
-code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit.
+code. TensorFlow also includes [TensorBoard](https://github.com/tensorflow/tensorboard),
+a data visualization toolkit.
TensorFlow was originally developed by researchers and engineers
working on the Google Brain team within Google's Machine Intelligence Research
@@ -111,7 +112,7 @@ The TensorFlow project strives to abide by generally accepted best practices in
Build Type | Status | Artifacts
---------------------------------- | ------------------------------ | ---------
**IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA
-**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA
+**IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | TBA
**IBM ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/)
**IBM ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | [Release](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/)
**Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/)
@@ -127,6 +128,7 @@ Build Type
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.

View File

@@ -1,5 +1,7 @@
workspace(name = "org_tensorflow")
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
    name = "io_bazel_rules_closure",
    sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae",
@@ -57,9 +59,9 @@ android_workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
-new_http_archive(
+http_archive(
    name = "inception_v1",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
    sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105",
    urls = [
        "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip",
@@ -67,9 +69,9 @@ new_http_archive(
    ],
)
-new_http_archive(
+http_archive(
    name = "mobile_ssd",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
    sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
    urls = [
        "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
@@ -77,9 +79,9 @@ new_http_archive(
    ],
)
-new_http_archive(
+http_archive(
    name = "mobile_multibox",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
    sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
    urls = [
        "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip",
@@ -87,9 +89,9 @@ new_http_archive(
    ],
)
-new_http_archive(
+http_archive(
    name = "stylize",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
    sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa",
    urls = [
        "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
@@ -97,9 +99,9 @@ new_http_archive(
    ],
)
-new_http_archive(
+http_archive(
    name = "speech_commands",
-    build_file = "models.BUILD",
+    build_file = "//:models.BUILD",
    sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c",
    urls = [
        "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip",

View File

@@ -238,6 +238,13 @@ def setup_python(environ_cp):
  write_to_bazelrc('build --python_path=\"%s"' % python_bin_path)
  environ_cp['PYTHON_BIN_PATH'] = python_bin_path
+  # If choosen python_lib_path is from a path specified in the PYTHONPATH
+  # variable, need to tell bazel to include PYTHONPATH
+  if environ_cp.get('PYTHONPATH'):
+    python_paths = environ_cp.get('PYTHONPATH').split(':')
+    if python_lib_path in python_paths:
+      write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH'))
  # Write tools/python_bin_path.sh
  with open(
      os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
@@ -445,11 +452,12 @@ def convert_version_to_int(version):
  return int(version_str)
-def check_bazel_version(min_version):
-  """Check installed bazel version is at least min_version.
+def check_bazel_version(min_version, max_version):
+  """Check installed bazel version is between min_version and max_version.
  Args:
    min_version: string for minimum bazel version.
+    max_version: string for maximum bazel version.
  Returns:
    The bazel version detected.
@@ -467,6 +475,7 @@ def check_bazel_version(min_version):
  min_version_int = convert_version_to_int(min_version)
  curr_version_int = convert_version_to_int(curr_version)
+  max_version_int = convert_version_to_int(max_version)
  # Check if current bazel version can be detected properly.
  if not curr_version_int:
@@ -480,6 +489,10 @@ def check_bazel_version(min_version):
    print('Please upgrade your bazel installation to version %s or higher to '
          'build TensorFlow!' % min_version)
    sys.exit(0)
+  if curr_version_int > max_version_int:
+    print('Please downgrade your bazel installation to version %s or lower to '
+          'build TensorFlow!' % max_version)
+    sys.exit(0)
  return curr_version
@@ -859,7 +872,7 @@ def set_tf_cuda_version(environ_cp):
    cuda_toolkit_paths_full = [
        os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
    ]
-    if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
+    if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
      break
    # Reset and retry
@@ -1552,7 +1565,7 @@ def main():
  # environment variables.
  environ_cp = dict(os.environ)
-  check_bazel_version('0.15.0')
+  check_bazel_version('0.15.0', '0.20.0')
  reset_tf_configure_bazelrc()
  # Explicitly import tools/bazel.rc, this is needed for Bazel 0.19.0 or later
@@ -1694,6 +1707,7 @@ def main():
    config_info_line('nohdfs', 'Disable HDFS support.')
    config_info_line('noignite', 'Disable Apacha Ignite support.')
    config_info_line('nokafka', 'Disable Apache Kafka support.')
+    config_info_line('nonccl', 'Disable NVIDIA NCCL support.')
if __name__ == '__main__':

View File

@@ -43,6 +43,11 @@ TENSORFLOW_API_INIT_FILES_V2 = (
    TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
)
+# @unused
+TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT = (
+    TENSORFLOW_API_INIT_FILES_V1 + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
# Config setting used when building for products
# which requires restricted licenses to be avoided.
config_setting(
@@ -213,31 +218,37 @@ config_setting(
#
config_setting(
    name = "no_aws_support",
-    define_values = {"no_aws_support": "false"},
+    define_values = {"no_aws_support": "true"},
    visibility = ["//visibility:public"],
)
config_setting(
    name = "no_gcp_support",
-    define_values = {"no_gcp_support": "false"},
+    define_values = {"no_gcp_support": "true"},
    visibility = ["//visibility:public"],
)
config_setting(
    name = "no_hdfs_support",
-    define_values = {"no_hdfs_support": "false"},
+    define_values = {"no_hdfs_support": "true"},
    visibility = ["//visibility:public"],
)
config_setting(
    name = "no_ignite_support",
-    define_values = {"no_ignite_support": "false"},
+    define_values = {"no_ignite_support": "true"},
    visibility = ["//visibility:public"],
)
config_setting(
    name = "no_kafka_support",
-    define_values = {"no_kafka_support": "false"},
+    define_values = {"no_kafka_support": "true"},
+    visibility = ["//visibility:public"],
+)
+config_setting(
+    name = "no_nccl_support",
+    define_values = {"no_nccl_support": "true"},
    visibility = ["//visibility:public"],
)
@@ -350,7 +361,7 @@ package_group(
        "-//third_party/tensorflow/python/estimator",
        "//learning/meta_rank/...",
        "//tensorflow/...",
-        "//tensorflow_estimator/...",
+        "//tensorflow_estimator/contrib/...",
        "//tensorflow_fold/llgtm/...",
        "//tensorflow_text/...",
        "//third_party/py/tensor2tensor/...",
@@ -554,18 +565,24 @@ genrule(
    }),
    outs = ["__init__.py"],
    cmd = select({
-        "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
-        "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+        "api_version_2": "cp $(@D)/_api/v2/v2.py $(OUTS)",
+        "//conditions:default": "cp $(@D)/_api/v1/v1.py $(OUTS)",
    }),
)
gen_api_init_files(
    name = "tf_python_api_gen_v1",
-    srcs = ["api_template_v1.__init__.py"],
+    srcs = [
+        "api_template_v1.__init__.py",
+        "compat_template_v1.__init__.py",
+    ],
    api_version = 1,
+    compat_api_versions = [1],
+    compat_init_templates = ["compat_template_v1.__init__.py"],
    output_dir = "_api/v1/",
-    output_files = TENSORFLOW_API_INIT_FILES_V1,
+    output_files = TENSORFLOW_API_INIT_FILES_V1_WITH_COMPAT,
    output_package = "tensorflow._api.v1",
+    root_file_name = "v1.py",
    root_init_template = "api_template_v1.__init__.py",
)
@@ -581,6 +598,7 @@ gen_api_init_files(
    output_dir = "_api/v2/",
    output_files = TENSORFLOW_API_INIT_FILES_V2,
    output_package = "tensorflow._api.v2",
+    root_file_name = "v2.py",
    root_init_template = "api_template.__init__.py",
)

View File

@@ -21,8 +21,6 @@ from __future__ import print_function as _print_function
import os as _os
# pylint: disable=g-bad-import-order
-from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
    parent_package_str=__name__,
@@ -30,16 +28,16 @@ _component_api_helper.package_hook(
# API IMPORTS PLACEHOLDER
-from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
-_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable
+# We're using bitwise, but there's nothing special about that.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(bitwise.__file__)) # pylint: disable=undefined-variable
if _tf_api_dir not in __path__:
  __path__.append(_tf_api_dir)
-# Calls to enable and disable features.
-enable_eager_execution() # pylint: disable=undefined-variable
+# Enable TF2 behaviors
+from tensorflow.python.compat import compat as _compat # pylint: disable=g-import-not-at-top
+_compat.enable_v2_behavior()
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They

View File

@@ -60,6 +60,7 @@ tf_cuda_library(
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
            "//tensorflow/core:op_gen_lib",
+            "//tensorflow/core/distributed_runtime:server_lib",
        ],
    }),
)
@@ -120,7 +121,8 @@ tf_cuda_library(
        ":c_api",
        ":c_api_internal",
        "//tensorflow/c/eager:c_api",
-        "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
+        "//tensorflow/c/eager:c_api_internal",
+        "//tensorflow/compiler/jit:flags",
        "//tensorflow/contrib/tpu:all_ops",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
@@ -173,6 +175,30 @@ tf_cuda_library(
    ],
)
tf_cuda_library(
name = "kernels",
srcs = [
"kernels.cc",
],
hdrs = [
"kernels.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
":c_api",
":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api",
":c_api_internal",
"//tensorflow/core:framework",
],
}),
)
# -----------------------------------------------------------------------------
# Tests
@@ -208,7 +234,10 @@ tf_cuda_cc_test(
        "//tensorflow:darwin": ["-headerpad_max_install_names"],
        "//conditions:default": [],
    }),
-    tags = ["noasan"],
+    tags = [
+        "no_oss", # http://b/119522529
+        "noasan",
+    ],
    # We must ensure that the dependencies can be dynamically linked since
    # the shared library must be able to use core:framework.
    # linkstatic = tf_kernel_tests_linkstatic(),
@@ -237,7 +266,7 @@ tf_cuda_cc_test(
tf_cc_test(
    name = "c_api_experimental_test",
-    size = "small",
+    size = "medium",
    srcs = ["c_api_experimental_test.cc"],
    data = ["testdata/tf_record"],
    linkopts = select({
@@ -248,8 +277,11 @@ tf_cc_test(
    # the shared library must be able to use core:framework.
    # linkstatic = tf_kernel_tests_linkstatic(),
    deps = [
+        ":c_api",
        ":c_api_experimental",
        ":c_test_util",
+        "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_test_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
@@ -300,6 +332,30 @@ tf_kernel_library(
    alwayslink = 1,
)
tf_cuda_cc_test(
name = "kernels_test",
size = "small",
srcs = ["kernels_test.cc"],
linkopts = select({
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":c_api",
":kernels",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# -----------------------------------------------------------------------------
# Python API target

View File

@@ -15,13 +15,18 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -51,8 +56,8 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
    // These XLA flags are needed to trigger XLA properly from C (more generally
    // non-Python) clients. If this API is called again with `enable` set to
    // false, it is safe to keep these flag values as is.
-    tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
-        tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+    tensorflow::MarkForCompilationPassFlags* flags =
+        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
@@ -71,8 +76,8 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
    // These XLA flags are needed to trigger XLA properly from C (more generally
    // non-Python) clients. If this API is called again with `enable` set to
    // false, it is safe to keep these flag values as is.
-    tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
-        tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+    tensorflow::MarkForCompilationPassFlags* flags =
+        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
@@ -6525,7 +6530,7 @@ library {
}
}
node_def {
-name: "ParallelInterleaveDataset/cycle_length"
+name: "ExperimentalParallelInterleaveDataset/cycle_length"
op: "Const"
attr {
key: "dtype"
@@ -6546,7 +6551,7 @@ library {
}
}
node_def {
-name: "ParallelInterleaveDataset/block_length"
+name: "ExperimentalParallelInterleaveDataset/block_length"
op: "Const"
attr {
key: "dtype"
@@ -6567,7 +6572,7 @@ library {
}
}
node_def {
-name: "ParallelInterleaveDataset/sloppy"
+name: "ExperimentalParallelInterleaveDataset/sloppy"
op: "Const"
attr {
key: "dtype"
@@ -6588,7 +6593,7 @@ library {
}
}
node_def {
-name: "ParallelInterleaveDataset/buffer_output_elements"
+name: "ExperimentalParallelInterleaveDataset/buffer_output_elements"
op: "Const"
attr {
key: "dtype"
@@ -6609,7 +6614,7 @@ library {
}
}
node_def {
-name: "ParallelInterleaveDataset/prefetch_input_elements"
+name: "ExperimentalParallelInterleaveDataset/prefetch_input_elements"
op: "Const"
attr {
key: "dtype"
@@ -6630,14 +6635,14 @@ library {
}
}
node_def {
-name: "ParallelInterleaveDataset"
-op: "ParallelInterleaveDataset"
+name: "ExperimentalParallelInterleaveDataset"
+op: "ExperimentalParallelInterleaveDataset"
input: "RepeatDataset:handle:0"
-input: "ParallelInterleaveDataset/cycle_length:output:0"
-input: "ParallelInterleaveDataset/block_length:output:0"
-input: "ParallelInterleaveDataset/sloppy:output:0"
-input: "ParallelInterleaveDataset/buffer_output_elements:output:0"
-input: "ParallelInterleaveDataset/prefetch_input_elements:output:0"
+input: "ExperimentalParallelInterleaveDataset/cycle_length:output:0"
+input: "ExperimentalParallelInterleaveDataset/block_length:output:0"
+input: "ExperimentalParallelInterleaveDataset/sloppy:output:0"
+input: "ExperimentalParallelInterleaveDataset/buffer_output_elements:output:0"
+input: "ExperimentalParallelInterleaveDataset/prefetch_input_elements:output:0"
attr {
key: "Targuments"
value {
@@ -6737,7 +6742,7 @@ library {
node_def {
name: "ShuffleDataset_2"
op: "ShuffleDataset"
-input: "ParallelInterleaveDataset:handle:0"
+input: "ExperimentalParallelInterleaveDataset:handle:0"
input: "ShuffleDataset_2/buffer_size_1:output:0"
input: "ShuffleDataset_2/seed_2:output:0"
input: "ShuffleDataset_2/seed2_2:output:0"
@@ -8739,14 +8744,65 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
  TF_DeleteStatus(status);
}
-TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
-                                                      const char* errMsg) {
+struct TFE_ExecuteOpNotification {
+  TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
tensorflow::Notification n;
std::unique_ptr<tensorflow::Thread> thread;
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};
TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
TFE_TensorHandle** retvals,
int* num_retvals,
TF_Status* status) {
TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
tensorflow::ThreadOptions(), "ExecuteOpThread",
[op, retvals, num_retvals, n]() {
TFE_Execute(op, retvals, num_retvals, n->status.get());
n->n.Notify();
}));
return n;
}
void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status) {
if (notification == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification is a nullptr.");
return;
}
if (notification->thread == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"Passed in notification didn't start a thread correctly. Cleaning up "
"this notification. Please re-execute the operation to get a new "
"notification.");
delete notification;
return;
}
notification->n.WaitForNotification();
status->status = notification->status->status;
delete notification;
}
+void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}
// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
+  // The string buffers to make sure that any `attr_name` we pass into
+  // `builder->Set()` will outlive the subsequent
+  // `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on the same `builder`.
+  std::set<std::string> attr_names;
};
TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
@@ -8757,13 +8813,15 @@ void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }
void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
-  builder->Set(attr_name, static_cast<tensorflow::DataType>(value));
+  auto iter = builder->attr_names.insert(attr_name).first;
+  builder->Set((*iter).c_str(), static_cast<tensorflow::DataType>(value));
}
void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
+  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(
-      attr_name,
+      (*iter).c_str(),
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}
@@ -8800,3 +8858,31 @@ const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}
int TF_OpIsStateful(const char* op_type, TF_Status* status) {
const tensorflow::OpRegistrationData* op_reg_data;
status->status =
tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
if (!status->status.ok()) {
return 0;
}
return op_reg_data->op_def.is_stateful();
}
void TF_InitMain(const char* usage, int* argc, char*** argv) {
tensorflow::port::InitMain(usage, argc, argv);
}
int TF_PickUnusedPortOrDie() {
return tensorflow::internal::PickUnusedPortOrDie();
}
TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType dtype_arg,
void* data, size_t len) {
auto dtype = static_cast<tensorflow::DataType>(dtype_arg);
DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));
tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
return new TFE_TensorHandle(tensor, nullptr, nullptr);
}
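The scalar fast path implemented just above can be exercised with a few lines of C. The following is a minimal, hypothetical usage sketch (not part of the commit); it relies only on the TFE_NewTensorHandleFromScalar declaration added to c_api_experimental.h and the standard eager C API, and it omits error handling:

```c
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"

// Hypothetical helper, not from the commit: builds a rank-0 float tensor
// handle directly from the scalar bytes, skipping the usual
// TF_AllocateTensor + TFE_NewTensorHandle round trip. The implementation
// above memcpys `len` bytes immediately, so passing a stack local is fine.
static TFE_TensorHandle* MakeFloatScalar(float value) {
  return TFE_NewTensorHandleFromScalar(TF_FLOAT, &value, sizeof(value));
}

// Example use:
//   TFE_TensorHandle* h = MakeFloatScalar(3.0f);
//   ... pass h to TFE_OpAddInput as usual ...
//   TFE_DeleteTensorHandle(h);
```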

View File

@@ -180,6 +180,25 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
    TFE_TensorHandle* handle);
typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
// Allows invoking a kernel asynchronously, and explicitly returns a
// notification that can be waited upon. This always executes the kernel in a
// new thread.
// 1. `retvals` and `num_retvals` can only be consumed after
// `TFE_ExecuteOp` returns successfully. They shouldn't be used
// if the return is unsuccessful
// 2. These new APIs cannot be used together with the TFE context level async
// support.
TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status);
// Waits to complete the op execution, and cleans up the notification.
// Errors reported by op execution are set in `status`.
TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
TFE_ExecuteOpNotification* notification, TF_Status* status);
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
                                                      const char* errMsg);
@@ -209,6 +228,24 @@ TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice(
TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput(
    const char* op_name, int input_index, TF_Status* status);
// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined
// if the status is not ok.
TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type,
TF_Status* status);
// Platform specific initialization routine. Very few platforms actually require
// this to be called.
TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv);
// Platform-specific implementation to return an unused port. (This should used
// in tests only.)
TF_CAPI_EXPORT int TF_PickUnusedPortOrDie();
// Fast path method that makes constructing a single scalar tensor require less
// overhead and copies.
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar(
TF_DataType dtype, void* scalar, size_t len);
#ifdef __cplusplus
} /* end extern "C" */
#endif
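For orientation, the call pattern intended by the notification API declared above is roughly the following. This is a condensed, hypothetical sketch (the complete version is the TFE_ExecuteOpInNewThreadTest_Simple test added to c_api_experimental_test.cc later in this commit); op construction and most error checks are elided:

```c
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"

// `op` is any TFE_Op already built via TFE_NewOp/TFE_OpAddInput.
void RunOpInNewThreadSketch(TFE_Op* op, TF_Status* status) {
  TFE_TensorHandle* retvals[1] = {NULL};
  int num_retvals = 1;
  // The kernel runs on a freshly spawned thread; the returned notification is
  // the only way to learn when it finished and whether it succeeded.
  TFE_ExecuteOpNotification* n =
      TFE_ExecuteOpInNewThread(op, &retvals[0], &num_retvals, status);
  // Per points 1 and 2 above: consume retvals only after a successful wait,
  // and do not combine this with context-level async execution.
  TFE_ExecuteOpNotificationWaitAndDelete(n, status);
  if (TF_GetCode(status) == TF_OK && retvals[0] != NULL) {
    TFE_DeleteTensorHandle(retvals[0]);
  }
}
```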

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h" #include "tensorflow/c/c_test_util.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
@ -162,5 +164,137 @@ protocol: "grpc"
TF_DeleteStatus(status); TF_DeleteStatus(status);
} }
TEST(CAPI_EXPERIMENTAL, IsStateful) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
int assign = TF_OpIsStateful("AssignAddVariableOp", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
EXPECT_EQ(assign, 1);
int id = TF_OpIsStateful("Identity", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
EXPECT_EQ(id, 0);
}
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul_op = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
auto* r =
TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(r, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteOp(matmul_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
// Perform a send/recv test. Recv blocks, so they need to be executed
// asynchronously.
TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
// Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
TFE_TensorHandle* m = TestMatrixTensorHandle();
// Build a send op.
TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(send_op, m, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
string tensor_name = "Tensor";
TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(send_op, "client_terminated", true);
// Build a recv op.
TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
tensor_name.size());
TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
send_device.size());
TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
recv_device.size());
TFE_OpSetAttrBool(recv_op, "client_terminated", true);
TFE_TensorHandle* send_retvals;
int send_num_retvals = 0;
auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
&send_num_retvals, status);
TFE_TensorHandle* recv_retvals[1] = {nullptr};
int recv_num_retvals = 1;
auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
&recv_num_retvals, status);
TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(1, product[0]);
EXPECT_EQ(2, product[1]);
EXPECT_EQ(3, product[2]);
EXPECT_EQ(4, product[3]);
TFE_DeleteOp(send_op);
TFE_DeleteOp(recv_op);
TFE_DeleteTensorHandle(m);
TFE_DeleteTensorHandle(recv_retvals[0]);
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@@ -392,26 +392,26 @@ Status ProcessInputs(
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  input_tensors->reserve(ninputs);
  for (int i = 0; i < ninputs; ++i) {
-    const Node& node = inputs[i].oper->node;
+    Node* node = &inputs[i].oper->node;
    int idx = inputs[i].index;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        fn_body->graph.IsValidOutputTensor(&node, idx),
+        fn_body->graph.IsValidOutputTensor(node, idx),
        "Encountered while processing input ", i, " into function '", fn_name,
        "'");
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
                                    "Encountered while processing input ", i,
                                    " into function '", fn_name, "'");
-    input_tensors->emplace_back(&node, idx);
-    const auto& iter = input_nodes->find(&node);
+    input_tensors->emplace_back(node, idx);
+    const auto& iter = input_nodes->find(node);
    if (iter == input_nodes->end()) {
-      input_nodes->insert({&node, {idx}});
+      input_nodes->insert({node, {idx}});
    } else {
      auto& indices = iter->second;
      if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
-        return InvalidArgument("TF_Output ", node.name(), ":", idx,
+        return InvalidArgument("TF_Output ", node->name(), ":", idx,
                               " appears more than once in the input list");
      }
      indices.push_back(idx);
@@ -428,16 +428,16 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
    EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
  output_tensors->reserve(noutputs);
  for (int i = 0; i < noutputs; ++i) {
-    const Node& node = outputs[i].oper->node;
+    Node* node = &outputs[i].oper->node;
    int idx = outputs[i].index;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
-        fn_body->graph.IsValidOutputTensor(&node, idx),
+        fn_body->graph.IsValidOutputTensor(node, idx),
        "Encountered while processing output ", i, " from function '", fn_name,
        "'");
-    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(&node, idx),
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
                                    "Encountered while creating function '",
                                    fn_name, "'");
-    output_tensors->emplace_back(&node, idx);
+    output_tensors->emplace_back(node, idx);
  }
  return Status::OK();
}

View File

@@ -50,6 +50,7 @@ tf_cuda_library(
        ],
        "//conditions:default": [],
    }) + [
+        "@com_google_absl//absl/memory",
        "//tensorflow/core/common_runtime/eager:eager_operation",
        "//tensorflow/core/distributed_runtime/eager:eager_client",
        "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@@ -143,6 +144,7 @@ tf_cuda_cc_test(
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "@com_google_absl//absl/strings",
    ],
)

View File

@@ -21,9 +21,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
+#include "tensorflow/core/platform/host_info.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif // TENSORFLOW_EAGER_USE_XLA
@@ -79,7 +81,7 @@ tensorflow::Status GetAllRemoteDevices(
    const std::vector<string>& remote_workers,
    tensorflow::WorkerCacheInterface* worker_cache,
    std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
-  std::vector<tensorflow::Device*> remote_devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
  tensorflow::Status status;
  // TODO(nareshmodi) do this in parallel instead of serially.
  for (const string& remote_worker : remote_workers) {
@@ -92,7 +94,7 @@ tensorflow::Status GetAllRemoteDevices(
          status = s;
          if (s.ok()) {
            for (tensorflow::Device* d : *devices) {
-              remote_devices.push_back(d);
+              remote_devices.emplace_back(d);
            }
          }
          n.Notify();
@@ -100,7 +102,7 @@ tensorflow::Status GetAllRemoteDevices(
    n.WaitForNotification();
  }
  std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
-      new tensorflow::DeviceMgr(remote_devices));
+      new tensorflow::DeviceMgr(std::move(remote_devices)));
  TF_RETURN_IF_ERROR(status);
@@ -261,13 +263,13 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
-  std::vector<tensorflow::Device*> devices;
+  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  status->status = tensorflow::DeviceFactory::AddDevices(
      opts->session_options.options, "/job:localhost/replica:0/task:0",
      &devices);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
-      new tensorflow::DeviceMgr(devices));
+      new tensorflow::DeviceMgr(std::move(devices)));
  tensorflow::Rendezvous* r =
      new tensorflow::IntraProcessRendezvous(device_mgr.get());
@@ -409,6 +411,18 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
                        : d->name().c_str();
}
const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
"The passed in handle is a nullptr");
return nullptr;
}
tensorflow::Device* d = h->handle->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr || h->handle == nullptr) {
@@ -458,13 +472,20 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name; // Shorthand
  const tensorflow::AttrTypeMap* types;
-  status->status = tensorflow::AttrTypeMapForOp(name, &types);
-  if (status->status.ok()) return new TFE_Op(ctx, name, types);
-  if (TF_GetCode(status) == TF_NOT_FOUND) {
-    if (ctx->context.FindFunctionByName(name)) {
-      status->status = tensorflow::Status::OK();
-      return new TFE_Op(ctx, name, nullptr);
+  bool is_function = false;
+  status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
+  if (status->status.ok()) {
+    if (is_function && !ctx->context.FindFunctionByName(name)) {
+      status->status = tensorflow::errors::NotFound(
+          "'", name,
+          "' is neither a type of a primitive operation nor a name "
+          "of a function registered in binary running on ",
+          tensorflow::port::Hostname(),
+          ". Make sure the operation or function is "
+          "registered in the binary running in this process.");
+      return nullptr;
    }
+    return new TFE_Op(ctx, name, is_function, types);
  }
  return nullptr;
}
@@ -497,12 +518,6 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret;
-  if (op->operation.is_function()) {
-    status->status = tensorflow::errors::Unimplemented(
-        "TODO(apassos): Support for attributes for TensorFlow functions is not "
-        "ready yet.");
-    return TF_ATTR_INT; // The compiler requires that we return something.
-  }
  status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
                                              attr_name, &ret, is_list);
  return ret;

View File

@@ -169,10 +169,33 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
                                                  int dim_index,
                                                  TF_Status* status);
// Returns the device of the operation that produced `h`.
// If `h` was produced by a copy, returns the destination device of
// the copy. Note that returned device name is not always the device
// holding the tensor handle's memory. If you want the latter, use
// TFE_TensorHandleBackingDeviceName.
// This function will block till the operation that produces `h` has completed.
//
// Device on which the kernel of the operation that produced `h` ran.
//
// If `h` was produced by a copy, returns the destination device of
// the copy.
//
// Note that returned device name is not always the device that owns the memory
// that backs the tensor handle. For the latter see
// TFE_TensorHandleBackingDeviceName.
//
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
    TFE_TensorHandle* h, TF_Status* status);
// Returns the name of the device in whose memory `h` resides.
//
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName(
TFE_TensorHandle* h, TF_Status* status);
// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
// with `h`. On success, `status` is set to OK. On failure, `status` reflects
// the error and a nullptr is returned.

View File

@@ -93,10 +93,9 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
-  // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
-  // primitive operation.
-  TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
-      : operation(&ctx->context, op, t) {}
+  TFE_Op(TFE_Context* ctx, const char* op, bool is_function,
+         const tensorflow::AttrTypeMap* t)
+      : operation(&ctx->context, op, is_function, t) {}
  tensorflow::EagerOperation operation;
};

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
+#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -589,9 +590,22 @@ void TensorHandleCopyBetweenTwoGPUDevices(bool async) {
  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
  const int num_devices = TF_DeviceListCount(devices);
bool has_gpu0 = false;
bool has_gpu1 = false;
for (int i = 0; i < num_devices; ++i) {
const char* dev = TF_DeviceListName(devices, i, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
string device_name(dev);
if (device_name.find("GPU:0") != string::npos) {
has_gpu0 = true;
}
if (device_name.find("GPU:1") != string::npos) {
has_gpu1 = true;
}
}
  const char* kCPUDevice = "CPU:0";
-  if (num_devices < 3) {
+  if (!has_gpu0 || !has_gpu1) {
    TF_DeleteDeviceList(devices);
    TF_DeleteTensor(t);
    TFE_DeleteTensorHandle(hcpu);
@@ -781,6 +795,14 @@ TEST(CAPI, TensorHandleNullptr) {
  TF_SetStatus(status.get(), TF_OK, "");
device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(device_name, nullptr);
ASSERT_EQ("The passed in handle is a nullptr",
string(TF_Message(status.get())));
TF_SetStatus(status.get(), TF_OK, "");
  int num_dims = TFE_TensorHandleNumDims(h, status.get());
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
  ASSERT_EQ(num_dims, -1);
@@ -796,6 +818,62 @@ TEST(CAPI, TensorHandleNullptr) {
            string(TF_Message(status.get())));
}
TEST(CAPI, TensorHandleDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
const char* backing_device_name =
TFE_TensorHandleBackingDeviceName(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
<< backing_device_name;
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* shape_op = ShapeOp(ctx, hgpu);
TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
// .device of shape is GPU since the op is executed on GPU
device_name = TFE_TensorHandleDeviceName(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name;
// .backing_device of shape is CPU since the tensor is backed by CPU
backing_device_name =
TFE_TensorHandleBackingDeviceName(retvals[0], status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
<< backing_device_name;
TFE_DeleteOp(shape_op);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TFE_DeleteTensorHandle(hcpu);
TFE_ContextAsyncWait(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContext(ctx);
}
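The test above separates the device that executed the op (`TFE_TensorHandleDeviceName`) from the device that actually holds the tensor's buffer (`TFE_TensorHandleBackingDeviceName`). A small illustrative helper built only from those two calls; the helper name and policy are assumptions, not part of the change:

```
// Returns true when the handle's data physically lives on the host, regardless
// of which device ran the op that produced it (e.g. the int32 output of Shape
// stays on the CPU even when Shape executed on a GPU).
bool TensorLivesOnHost(TFE_TensorHandle* h, TF_Status* status) {
  const char* backing = TFE_TensorHandleBackingDeviceName(h, status);
  if (TF_GetCode(status) != TF_OK) return false;
  return absl::StrContains(backing, "CPU:0");
}
```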
void Execute_MatMul_CPU(bool async) { void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus(); TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptions* opts = TFE_NewContextOptions();


@ -104,6 +104,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
return op; return op;
} }
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
TF_Status* status = TF_NewStatus();
TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_OpAddInput(op, a, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
}
TFE_TensorHandle* TestAxisTensorHandle() { TFE_TensorHandle* TestAxisTensorHandle() {
int64_t dims[] = {1}; int64_t dims[] = {1};
int data[] = {1}; int data[] = {1};


@ -37,6 +37,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2();
// Return a matmul op multiplying `a` by `b`. // Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
// Return a shape op fetching the shape of `a`.
TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
// Return a 1-D INT32 tensor containing a single value 1. // Return a 1-D INT32 tensor containing a single value 1.
TFE_TensorHandle* TestAxisTensorHandle(); TFE_TensorHandle* TestAxisTensorHandle();


@ -141,8 +141,9 @@ class GradientTape {
// null. The result is populated with one tensor per target element. // null. The result is populated with one tensor per target element.
Status ComputeGradient( Status ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids, const gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_id, const gtl::ArraySlice<int64> source_tensor_ids,
const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result); std::vector<Gradient*>* result);
@ -396,6 +397,7 @@ template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status InitialGradients( Status InitialGradients(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids, gtl::ArraySlice<int64> target_tensor_ids,
gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape, gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
const OpTape<BackwardFunction, TapeTensor>& op_tape, const OpTape<BackwardFunction, TapeTensor>& op_tape,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) { gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
@ -425,8 +427,13 @@ Status InitialGradients(
"none of operations outputs match expected tensor"); "none of operations outputs match expected tensor");
} }
   } else {
-    // No record of the target tensor found on the tape, so no gradient
-    // needs to be computed from it. Do nothing.
+    // This target tensor was not generated by any operation recorded on
+    // the tape, so no gradient needs to be computed from it unless this
+    // target is also a source.
+    auto source_tensor = sources_that_are_targets.find(id);
+    if (source_tensor != sources_that_are_targets.end()) {
+      (*result)[id].push_back(vspace.Ones(source_tensor->second));
+    }
   }
} else { } else {
(*result)[id].push_back(output_gradients[i]); (*result)[id].push_back(output_gradients[i]);
@ -467,8 +474,9 @@ constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
template <typename Gradient, typename BackwardFunction, typename TapeTensor> template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient( Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace, const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids, const gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_ids, const gtl::ArraySlice<int64> source_tensor_ids,
const gtl::FlatMap<int64, TapeTensor> sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) { std::vector<Gradient*>* result) {
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(), gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
@ -478,7 +486,8 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
std::vector<int64> op_stack = std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor); InitialStack(state.op_tape, state.op_missing_tensor);
gtl::FlatMap<int64, std::vector<Gradient*>> gradients; gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, Status s = InitialGradients(vspace, target_tensor_ids,
sources_that_are_targets, output_gradients,
tensor_tape_, state.op_tape, &gradients); tensor_tape_, state.op_tape, &gradients);
auto cleanup = [this, &state]() { auto cleanup = [this, &state]() {
if (!persistent_) { if (!persistent_) {
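The `sources_that_are_targets` map changes how gradients are seeded when a target tensor was not produced by any recorded op: previously nothing was seeded, now a tensor of ones is seeded if the target is also a watched source (the gradient of a tensor with respect to itself). A self-contained sketch of that rule using simplified stand-in types; none of the names below are the tape's real API:

```
#include <iostream>
#include <unordered_set>
#include <vector>

using TensorId = long long;  // stand-in for the tape's int64 tensor ids

// Mirrors the new InitialGradients rule for a target no taped op produced:
// seed ones only when the target is also a source, otherwise seed nothing.
std::vector<double> SeedForUnproducedTarget(
    TensorId target,
    const std::unordered_set<TensorId>& sources_that_are_targets) {
  if (sources_that_are_targets.count(target) > 0) {
    return {1.0};  // plays the role of vspace.Ones(...)
  }
  return {};  // the old behavior for every unproduced target: "Do nothing"
}

int main() {
  const std::unordered_set<TensorId> watched = {7};
  // d(x)/d(x) for a directly watched tensor x (id 7) now seeds to ones.
  std::cout << SeedForUnproducedTarget(7, watched).size() << "\n";  // prints 1
  std::cout << SeedForUnproducedTarget(9, watched).size() << "\n";  // prints 0
  return 0;
}
```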

tensorflow/c/kernels.cc — new file, 143 lines

@ -0,0 +1,143 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.
struct TF_KernelBuilder {
::tensorflow::KernelDefBuilder* cc_builder;
void* (*create_function)(TF_OpKernelConstruction*);
void (*compute_function)(void*, TF_OpKernelContext*);
void (*delete_function)(void*);
};
TF_KernelBuilder* TF_NewKernelBuilder(
const char* op_name, const char* device_name,
void* (*create_func)(TF_OpKernelConstruction*),
void (*compute_func)(void*, TF_OpKernelContext*),
void (*delete_func)(void*)) {
TF_KernelBuilder* result = new TF_KernelBuilder;
result->cc_builder = new ::tensorflow::KernelDefBuilder(op_name);
result->cc_builder->Device(device_name);
result->create_function = create_func;
result->compute_function = compute_func;
result->delete_function = delete_func;
return result;
}
void TF_DeleteKernelBuilder(TF_KernelBuilder* builder) {
DCHECK_NE(builder, nullptr);
delete builder->cc_builder;
delete builder;
}
namespace tensorflow {
namespace {
// An OpKernel whose methods delegate to C function pointers.
class COpKernel : public OpKernel {
public:
explicit COpKernel(OpKernelConstruction* ctx,
void* (*create_func)(TF_OpKernelConstruction*),
void (*compute_func)(void*, TF_OpKernelContext*),
void (*delete_func)(void*))
: OpKernel(ctx), compute_func_(compute_func), delete_func_(delete_func) {
if (create_func != nullptr) {
c_kernel_ =
(*create_func)(reinterpret_cast<TF_OpKernelConstruction*>(ctx));
} else {
c_kernel_ = nullptr;
}
}
void Compute(OpKernelContext* ctx) override {
(*compute_func_)(c_kernel_, reinterpret_cast<TF_OpKernelContext*>(ctx));
}
~COpKernel() override {
if (delete_func_ != nullptr) {
(*delete_func_)(c_kernel_);
}
}
private:
void (*compute_func_)(void*, TF_OpKernelContext* context);
void (*delete_func_)(void*);
void* c_kernel_;
};
// A KernelFactory that returns COpKernel instances.
class KernelBuilderFactory
: public ::tensorflow::kernel_factory::OpKernelFactory {
public:
explicit KernelBuilderFactory(TF_KernelBuilder* builder)
: builder_(builder) {}
::tensorflow::OpKernel* Create(
::tensorflow::OpKernelConstruction* context) override {
return new ::tensorflow::COpKernel(context, builder_->create_function,
builder_->compute_function,
builder_->delete_function);
}
~KernelBuilderFactory() override { TF_DeleteKernelBuilder(builder_); }
private:
TF_KernelBuilder* builder_;
};
} // namespace
} // namespace tensorflow
void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder,
TF_Status* status) {
using tensorflow::register_kernel::Name;
tensorflow::kernel_factory::OpKernelRegistrar(
builder->cc_builder->Build(), name,
absl::make_unique<tensorflow::KernelBuilderFactory>(builder));
TF_SetStatus(status, TF_OK, "");
}
int TF_NumInputs(TF_OpKernelContext* ctx) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
return cc_ctx->num_inputs();
}
int TF_NumOutputs(TF_OpKernelContext* ctx) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
return cc_ctx->num_outputs();
}
void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
if (i < 0 || i >= cc_ctx->num_inputs()) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range");
return;
}
const ::tensorflow::Tensor& cc_tensor(cc_ctx->input(i));
TF_Tensor* result = ::tensorflow::TF_TensorFromTensor(cc_tensor, status);
if (TF_GetCode(status) == TF_OK) {
*tensor = result;
}
}

tensorflow/c/kernels.h — new file, 110 lines

@ -0,0 +1,110 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_KERNELS_H_
#define TENSORFLOW_C_KERNELS_H_
#include "tensorflow/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// --------------------------------------------------------------------------
// C API for TensorFlow Kernels.
//
// This API allows developers to register custom kernel implementations for
// TensorFlow.
//
// See c_api.h header comments for a discussion about API conventions.
//
// Users wishing to extend TensorFlow with new kernels will call
// `TF_NewKernelBuilder`. The resulting kernel builder can be registered with
// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided
// kernels when necessary.
struct TF_KernelBuilder;
struct TF_OpKernelConstruction;
struct TF_OpKernelContext;
// Allocates a new kernel builder and returns a pointer to it.
//
// If non-null, TensorFlow will call create_func when it needs to instantiate
// the kernel. The pointer returned by create_func will be passed to
// compute_func and delete_func, thereby functioning as a "this" pointer for
// referring to kernel instances.
//
// The TF_OpKernelConstruction pointer passed to create_func is owned by
// TensorFlow and will be deleted once create_func returns. It must not be used
// after this.
//
// When TensorFlow needs to perform a computation with this kernel, it will
// call compute_func. This function will receive the pointer returned by
// create_func (or null if no create_func was provided), along with the inputs
// to the computation.
//
// The TF_OpKernelContext pointer received by compute_func is owned by
// TensorFlow and will be deleted once compute_func returns. It must not be used
// after this.
//
// Finally, when TensorFlow no longer needs the kernel, it will call
// delete_func if one is provided. This function will receive the pointer
// returned by `create_func`, or nullptr if no `create_func` was provided.
//
// The caller should pass the result of this function to
// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for
// some reason, the kernel builder will not be registered, the caller should
// delete it with TF_DeleteKernelBuilder.
TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder(
const char* op_name, const char* device_name,
void* (*create_func)(TF_OpKernelConstruction*),
void (*compute_func)(void*, TF_OpKernelContext*),
void (*delete_func)(void*));
// Register the given kernel builder with the TensorFlow runtime. If
// registration fails, the given status will be populated.
//
// This call takes ownership of the `builder` pointer.
TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name,
TF_KernelBuilder* builder,
TF_Status* status);
// Deletes the given TF_KernelBuilder. This should be called only if the kernel
// builder is not registered with TensorFlow via TF_RegisterKernelBuilder.
TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder);
// --------------------------------------------------------------------------
// OpKernelContext routines
// TF_NumInputs returns the number of inputs available in ctx.
TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx);
// TF_NumOutputs returns the number of outputs to be placed in *ctx by the
// kernel.
TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx);
// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is
// populated and its ownership is passed to the caller. In any other case,
// *tensor is not modified.
//
// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE.
TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i,
TF_Tensor** tensor, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_KERNELS_H_
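Putting the header's contract together, here is a hedged end-to-end sketch of registering a C kernel. The op name, device string, and kernel struct are illustrative (the op itself would have to be registered separately); the API calls are the ones declared above:

```
#include <stdlib.h>

#include "tensorflow/c/kernels.h"

// Illustrative per-instance state; any heap object returned by create_func works.
typedef struct { int calls; } MyKernel;

static void* MyCreate(TF_OpKernelConstruction* ctx) {
  (void)ctx;  // construction-time attributes are not needed in this sketch
  MyKernel* k = (MyKernel*)malloc(sizeof(MyKernel));
  k->calls = 0;
  return k;  // becomes the "this" pointer passed to compute_func/delete_func
}

static void MyCompute(void* kernel, TF_OpKernelContext* ctx) {
  MyKernel* k = (MyKernel*)kernel;
  k->calls++;
  TF_Status* s = TF_NewStatus();
  TF_Tensor* in = NULL;
  TF_GetInput(ctx, 0, &in, s);  // on TF_OK, ownership of *in passes to us
  if (TF_GetCode(s) == TF_OK) TF_DeleteTensor(in);
  TF_DeleteStatus(s);
}

static void MyDelete(void* kernel) { free(kernel); }

// Called once, e.g. at library load time (the trigger is up to the plugin).
void RegisterMyKernel(void) {
  TF_KernelBuilder* b =
      TF_NewKernelBuilder("MyOp", "CPU", &MyCreate, &MyCompute, &MyDelete);
  TF_Status* s = TF_NewStatus();
  TF_RegisterKernelBuilder("MyKernelName", b, s);  // takes ownership of b
  TF_DeleteStatus(s);
}
```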


@ -0,0 +1,194 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def.pb_text.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
struct MyCustomKernel {
bool created;
bool compute_called;
};
static bool delete_called = false;
static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
struct MyCustomKernel* s = new struct MyCustomKernel;
s->created = true;
s->compute_called = false;
return s;
}
static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
s->compute_called = true;
}
static void MyDeleteFunc(void* kernel) {
struct MyCustomKernel* s = static_cast<struct MyCustomKernel*>(kernel);
EXPECT_TRUE(s->created);
EXPECT_TRUE(s->compute_called);
delete_called = true;
delete s;
}
namespace tensorflow {
static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
const char* op_name,
Status* status) {
NodeDef def;
def.set_op(op_name);
def.set_device(device_name);
def.add_input("input1");
def.add_input("input2");
return CreateOpKernel(DeviceType(device_name), nullptr, nullptr, def, 1,
status);
}
// Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made.
TEST(TestKernel, TestRegisterKernelBuilder) {
const char* kernel_name = "SomeKernelName";
const char* op_name = "FooOp";
const char* device_name = "FakeDeviceName1";
REGISTER_OP(op_name)
.Input("input1: double")
.Input("input2: uint8")
.Output("output1: uint8");
TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
KernelList list;
list.ParseFromArray(buf->data, buf->length);
ASSERT_EQ(1, list.kernel_size());
ASSERT_EQ(device_name, list.kernel(0).device_type());
TF_DeleteBuffer(buf);
TF_DeleteStatus(status);
}
{
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
kernel->Compute(nullptr);
}
ASSERT_TRUE(delete_called);
}
class DummyDevice : public DeviceBase {
public:
DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
bool RequiresRecordingAccessedTensors() const override { return save_; }
Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
return cpu_allocator();
}
private:
bool save_;
};
TEST(TestKernel, TestInputAndOutputCount) {
const char* kernel_name = "InputOutputCounterKernel";
const char* op_name = "BarOp";
const char* device_name = "FakeDeviceName2";
REGISTER_OP(op_name)
.Input("input1: double")
.Input("input2: uint8")
.Output("output1: uint8");
static int num_inputs = 0;
static int num_outputs = 0;
// A kernel whose Compute function has a side-effect of updating num_inputs
// and num_outputs. Various functions on TF_OpKernelContext are also
// exercised.
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
num_inputs = TF_NumInputs(ctx);
num_outputs = TF_NumOutputs(ctx);
TF_Tensor* input = nullptr;
TF_Status* s = TF_NewStatus();
TF_GetInput(ctx, 0, &input, s);
EXPECT_EQ(TF_OK, TF_GetCode(s)) << "Failed to get input: " << TF_Message(s);
EXPECT_EQ(123, *static_cast<tensorflow::uint8*>(TF_TensorData(input)));
TF_GetInput(ctx, -1, &input, s);
EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
TF_GetInput(ctx, 3, &input, s);
EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s));
TF_DeleteStatus(s);
if (input != nullptr) {
TF_DeleteTensor(input);
}
};
TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
my_compute_func, nullptr);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
{
OpKernelContext::Params p;
DummyDevice dummy_device(nullptr, false);
p.device = &dummy_device;
Tensor t(tensorflow::uint8(123));
gtl::InlinedVector<TensorValue, 4> inputs;
// Simulate 2 inputs
inputs.emplace_back(&t);
inputs.emplace_back();
p.inputs = &inputs;
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
p.op_kernel = kernel.get();
OpKernelContext ctx(&p);
kernel->Compute(&ctx);
ASSERT_EQ(2, num_inputs);
ASSERT_EQ(1, num_outputs);
}
}
} // namespace tensorflow


@ -160,4 +160,17 @@ void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
} }
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
TF_Status* status) {
mutex_lock l(graph->mu);
status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
new_src.index, &dst->node);
if (status->status.ok()) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
RecordMutation(graph, *dst, "adding input tensor");
}
}
} // namespace tensorflow } // namespace tensorflow


@ -34,6 +34,7 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device);
// Updates 'dst' to consume 'new_src'.
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status); TF_Status* status);
@ -65,6 +66,13 @@ std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// because I couldn't get SWIG to work otherwise. // because I couldn't get SWIG to work otherwise.
void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto, void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
size_t proto_len, TF_Status* status); size_t proto_len, TF_Status* status);
// This method is used to add a new input edge to 'dst', which must be a While
// op. The While op's "T" attribute must have already been updated to include
// the new edge. This is used to construct tf.while_loop gradients.
void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
TF_Status* status);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_ #endif // TENSORFLOW_C_PYTHON_API_H_
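`AddWhileInputHack` is only invoked from the Python layer through SWIG, but its call pattern follows directly from the declaration. A hedged sketch; the graph, While operation, and new source below are assumptions, not taken from this commit:

```
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/python_api.h"

// Assumes `while_op` is a While node whose "T" attribute has already been
// extended for the extra input, as the comment above requires.
void AddGradientInputToWhile(TF_Graph* graph, TF_Output new_src,
                             TF_Operation* while_op) {
  TF_Status* status = TF_NewStatus();
  tensorflow::AddWhileInputHack(graph, new_src, while_op, status);
  if (TF_GetCode(status) != TF_OK) {
    // Caller-specific handling of TF_Message(status) goes here.
  }
  TF_DeleteStatus(status);
}
```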


@ -489,6 +489,7 @@ tf_gen_op_wrappers_cc(
"image_ops", "image_ops",
"io_ops", "io_ops",
"linalg_ops", "linalg_ops",
"list_ops",
"logging_ops", "logging_ops",
"lookup_ops", "lookup_ops",
"manip_ops", "manip_ops",


@ -133,5 +133,6 @@ filegroup(
"testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**", "testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**", "testdata/half_plus_two/**",
"testdata/half_plus_two_v2/**",
]), ]),
) )


@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
/// SavedModel text format proto filename. /// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
/// SavedModel legacy init op key. /// SavedModel legacy init op collection key. Used in v1 SavedModels.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
/// SavedModel main op key. /// SavedModel main op collection key. Used in v1 SavedModels.
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
/// Directory in which to save the SavedModel variables. /// Directory in which to save the SavedModel variables.
@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
/// SavedModel variables filename. /// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "variables"; constexpr char kSavedModelVariablesFilename[] = "variables";
/// SavedModel SignatureDef keys for the initialization and train ops. Used in
/// V2 SavedModels.
constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ #endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_


@ -122,38 +122,58 @@ Status RunOnce(const RunOptions& run_options,
return run_status; return run_status;
} }
-bool HasMainOp(const MetaGraphDef& meta_graph_def) {
-  const auto& collection_def_map = meta_graph_def.collection_def();
-  if (collection_def_map.find(kSavedModelMainOpKey) !=
-      collection_def_map.end()) {
-    return true;
-  }
-  return false;
-}
-
-Status RunMainOp(const RunOptions& run_options, const string& export_dir,
-                 const MetaGraphDef& meta_graph_def,
-                 const std::vector<AssetFileDef>& asset_file_defs,
-                 Session* session, const string& main_op_key) {
-  LOG(INFO) << "Running MainOp with key " << main_op_key
-            << " on SavedModel bundle.";
-  const auto& collection_def_map = meta_graph_def.collection_def();
-  const auto main_op_it = collection_def_map.find(main_op_key);
-  if (main_op_it != collection_def_map.end()) {
-    if (main_op_it->second.node_list().value_size() != 1) {
-      return errors::FailedPrecondition(
-          strings::StrCat("Expected exactly one main op in : ", export_dir));
-    }
+// RunInitOp will return OK if the initialization op was run successfully.
+// An empty init_op_name indicates that there are no init ops to run.
+Status RunInitOp(const RunOptions& run_options, const string& export_dir,
+                 const MetaGraphDef& meta_graph_def,
+                 const std::vector<AssetFileDef>& asset_file_defs,
+                 Session* session, const string& init_op_name) {
+  if (!init_op_name.empty()) {
+    LOG(INFO) << "Running initialization op on SavedModel bundle.";
     std::vector<std::pair<string, Tensor>> inputs;
     AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
     RunMetadata run_metadata;
-    const StringPiece main_op_name = main_op_it->second.node_list().value(0);
-    return RunOnce(run_options, inputs, {}, {string(main_op_name)},
+    return RunOnce(run_options, inputs, {}, {init_op_name},
                    nullptr /* outputs */, &run_metadata, session);
   }
   return Status::OK();
 }
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or in a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
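For reference, the v2 layout that this lookup expects: a `signature_def` entry keyed by `__saved_model_init_op` whose output of the same key names the op to run. A hedged sketch of building such a MetaGraphDef by hand; the op name is made up:

```
#include "tensorflow/core/protobuf/meta_graph.pb.h"

tensorflow::MetaGraphDef MakeV2StyleInitSignature() {
  tensorflow::MetaGraphDef meta_graph_def;
  tensorflow::SignatureDef init_sig;
  // The init op is recorded as an output tensor name under the same key that
  // GetInitOp reads first.
  (*init_sig.mutable_outputs())["__saved_model_init_op"].set_name(
      "init_all_tables");  // illustrative op name
  (*meta_graph_def.mutable_signature_def())["__saved_model_init_op"] = init_sig;
  return meta_graph_def;
}
```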
Status RunRestore(const RunOptions& run_options, const string& export_dir, Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name, const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name, const StringPiece variable_filename_const_op_name,
@ -193,6 +213,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) { std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def(); const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey); const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) { if (assets_it == collection_def_map.end()) {
@ -227,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get())); asset_file_defs, bundle->session.get()));
-  if (HasMainOp(bundle->meta_graph_def)) {
-    TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
-                                 bundle->meta_graph_def, asset_file_defs,
-                                 bundle->session.get(), kSavedModelMainOpKey));
-  } else {
-    TF_RETURN_IF_ERROR(RunMainOp(
-        run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
-        bundle->session.get(), kSavedModelLegacyInitOpKey));
-  }
+  string init_op_name;
+  TF_RETURN_IF_ERROR(
+      GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
+  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
+                               asset_file_defs, bundle->session.get(),
+                               init_op_name));
return Status::OK(); return Status::OK();
} }


@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] =
"cc/saved_model/testdata/half_plus_two_main_op/00000123"; "cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] = constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123"; "cc/saved_model/testdata/half_plus_two/00000123";
constexpr char kTestDataInitOpV2[] =
"cc/saved_model/testdata/half_plus_two_v2/00000123";
class LoaderTest : public ::testing::Test { class LoaderTest : public ::testing::Test {
protected: protected:
@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir)); EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
} }
TEST_F(LoaderTest, SavedModelInitOpV2Format) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow


@ -0,0 +1 @@
asset-file-contents


@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code,
} }
// Generate methods for args (inputs). // Generate methods for args (inputs).
Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, Status GenArgMethods(const tf2xla::Config& config,
const xla::ProgramShapeProto& ps,
const CompileResult& compile_result, string* methods) { const CompileResult& compile_result, string* methods) {
size_t num_args = ps.parameters_size(); size_t num_args = ps.parameters_size();
if (config.feed_size() != num_args) { if (config.feed_size() != num_args) {
@ -174,9 +175,10 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
} }
for (int i = 0; i < num_args; ++i) { for (int i = 0; i < num_args; ++i) {
std::vector<std::pair<string, string>> rewrites; std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); TF_RETURN_IF_ERROR(
AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
const string code = R"( const string code = R"(
void set_arg{{NAME}}_data(void* data) { void set_arg{{NAME}}_data(const void* data) {
set_arg_data({{I}}, data); set_arg_data({{I}}, data);
} }
{{TYPE}}* arg{{NAME}}_data() { {{TYPE}}* arg{{NAME}}_data() {
@ -204,7 +206,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
// Generate methods for results (outputs). // Generate methods for results (outputs).
Status GenResultMethods(const tf2xla::Config& config, Status GenResultMethods(const tf2xla::Config& config,
const xla::ProgramShape& ps, string* methods) { const xla::ProgramShapeProto& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) { if (ps.result().element_type() != xla::TUPLE) {
// The XlaCompiler we use to build the xla computation always generates a // The XlaCompiler we use to build the xla computation always generates a
// tuple result, and we rely on this to simplify code generation. // tuple result, and we rely on this to simplify code generation.
@ -217,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config,
} }
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector<std::pair<string, string>> rewrites; std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(AddRewritesForShape(
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
string code = R"( string code = R"(
{{TYPE}}* result{{NAME}}_data() { {{TYPE}}* result{{NAME}}_data() {
return static_cast<{{TYPE}}*>(result_data({{I}})); return static_cast<{{TYPE}}*>(result_data({{I}}));
@ -336,7 +338,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
ExtractEntryParamBufferInfos(buffer_infos); ExtractEntryParamBufferInfos(buffer_infos);
std::vector<BufferInfo> buffer_infos_for_temps = std::vector<BufferInfo> buffer_infos_for_temps =
ExtractTempBufferInfos(buffer_infos); ExtractTempBufferInfos(buffer_infos);
const xla::ProgramShape& ps = compile_result.program_shape; const xla::ProgramShapeProto& ps = compile_result.program_shape;
string methods_arg, methods_result; string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
@ -548,8 +550,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static const char** StaticResultNames() {{RESULT_NAMES_CODE}} static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
// Shape of the args and results. // Shape of the args and results.
static const xla::ProgramShape* StaticProgramShape() { static const xla::ProgramShapeProto* StaticProgramShape() {
static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape; return kShape;
} }
@ -587,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{METHODS_RESULT}}\n", methods_result}, {"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end}, {"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start}, {"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim}, metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
@ -615,11 +617,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts,
Status GenerateMetadata(const CodegenOpts& opts, Status GenerateMetadata(const CodegenOpts& opts,
const CompileResult& compile_result, const CompileResult& compile_result,
MetadataResult* metadata_result) { MetadataResult* metadata_result) {
std::unique_ptr<xla::ProgramShape> program_shape; std::unique_ptr<xla::ProgramShapeProto> program_shape;
if (opts.gen_program_shape) { if (opts.gen_program_shape) {
program_shape = program_shape =
absl::make_unique<xla::ProgramShape>(compile_result.program_shape); absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
// The parameter names are currently meaningless, and redundant with the // The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save // rest of our metadata, so clear them out to avoid confusion and save
@ -631,8 +633,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
// a shim that evaluates to nullptr, which is what we want. // a shim that evaluates to nullptr, which is what we want.
ProtobufToEmbed program_shape_protobuf{ ProtobufToEmbed program_shape_protobuf{
CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", CreateUniqueIdentifier(opts, "ProgramShapeProto"),
program_shape.get()}; "xla::ProgramShapeProto", program_shape.get()};
ProtobufToEmbed hlo_profile_printer_data_protobuf{ ProtobufToEmbed hlo_profile_printer_data_protobuf{
CreateUniqueIdentifier(opts, "HloProfilePrinterData"), CreateUniqueIdentifier(opts, "HloProfilePrinterData"),


@ -57,7 +57,7 @@ struct MetadataResult {
std::vector<string> header_variable_decls; std::vector<string> header_variable_decls;
// program_shape_access_shim is a C++ expression that constructs the // program_shape_access_shim is a C++ expression that constructs the
// xla::ProgramShape instance for the CompileResult passed to // xla::ProgramShapeProto instance for the CompileResult passed to
// GenerateMetadata. // GenerateMetadata.
string program_shape_access_shim; string program_shape_access_shim;


@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) {
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {})); 5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( compile_result.program_shape =
{ xla::ShapeUtil::MakeProgramShape(
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), {
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
}, xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
xla::ShapeUtil::MakeTupleShape( },
{xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
.ToProto();
compile_result.entry_point = "entry_point"; compile_result.entry_point = "entry_point";
compile_result.pointer_size = 8; compile_result.pointer_size = 8;


@ -22,7 +22,7 @@ extern "C" void entry_point(
void* result, const xla::ExecutableRunOptions* run_options, void* result, const xla::ExecutableRunOptions* run_options,
const void** args, void** temps, tensorflow::int64* profile_counters); const void** args, void** temps, tensorflow::int64* profile_counters);
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
namespace foo { namespace foo {
@ -114,7 +114,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
// with dim indices specifying which value. No bounds checking is performed // with dim indices specifying which value. No bounds checking is performed
// on dim indices. // on dim indices.
void set_arg0_data(void* data) { void set_arg0_data(const void* data) {
set_arg_data(0, data); set_arg_data(0, data);
} }
float* arg0_data() { float* arg0_data() {
@ -132,7 +132,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
arg_data(0)))[dim0][dim1]; arg_data(0)))[dim0][dim1];
} }
void set_arg_myfeed_data(void* data) { void set_arg_myfeed_data(const void* data) {
set_arg_data(0, data); set_arg_data(0, data);
} }
float* arg_myfeed_data() { float* arg_myfeed_data() {
@ -150,7 +150,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
arg_data(0)))[dim0][dim1]; arg_data(0)))[dim0][dim1];
} }
void set_arg1_data(void* data) { void set_arg1_data(const void* data) {
set_arg_data(1, data); set_arg_data(1, data);
} }
tensorflow::int64* arg1_data() { tensorflow::int64* arg1_data() {
@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
} }
// Shape of the args and results. // Shape of the args and results.
static const xla::ProgramShape* StaticProgramShape() { static const xla::ProgramShapeProto* StaticProgramShape() {
static const xla::ProgramShape* kShape = []() { static const xla::ProgramShapeProto* kShape = []() {
xla::ProgramShape* proto = new xla::ProgramShape; xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52); proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
return proto; return proto;
}(); }();
return kShape; return kShape;


@ -56,17 +56,23 @@ Status CompileXla(xla::CompileOnlyClient* client,
return errors::Unknown("Couldn't get XLA program shape: ", return errors::Unknown("Couldn't get XLA program shape: ",
pshape_or.status().error_message()); pshape_or.status().error_message());
} }
-  compile_result->program_shape = *pshape_or.ValueOrDie();
-  xla::ProgramShape* pshape = &compile_result->program_shape;
-  std::vector<const xla::Shape*> arg_layouts;
-  arg_layouts.reserve(pshape->parameters_size());
+  compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
+  xla::ProgramShapeProto* pshape = &compile_result->program_shape;
+  // AotXlaComputationInstance::argument_layouts is a vector of Shape
+  // pointers. Accumulate the Shape objects themselves in a separate vector
+  // while building the vector of pointers.
+  std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
+  std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
   for (int i = 0; i < pshape->parameters_size(); ++i) {
-    arg_layouts.push_back(pshape->mutable_parameters(i));
+    arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
+    arg_layout_ptrs[i] = &arg_layouts[i];
   }
   xla::CompileOnlyClient::AotXlaComputationInstance instance;
   instance.computation = &computation;
-  instance.argument_layouts = std::move(arg_layouts);
-  instance.result_layout = &pshape->result();
+  instance.argument_layouts = std::move(arg_layout_ptrs);
+  xla::Shape result_shape(pshape->result());
+  instance.result_layout = &result_shape;
xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>> xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
aot_or = client->CompileAheadOfTime({instance}, aot_opts); aot_or = client->CompileAheadOfTime({instance}, aot_opts);
if (!aot_or.ok()) { if (!aot_or.ok()) {
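The conversions used above are the standard proto round trip between `xla::ProgramShape` and `xla::ProgramShapeProto`. A compact hedged sketch, assuming an XLA build environment and arbitrary shapes:

```
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Builds a trivial program shape, serializes it the way CompileResult now
// stores it, and rebuilds the C++ objects the way the codegen consumes them.
xla::ProgramShape RoundTripProgramShape() {
  xla::ProgramShape shape = xla::ShapeUtil::MakeProgramShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {2, 2})},
      xla::ShapeUtil::MakeShape(xla::F32, {2, 2}));
  xla::ProgramShapeProto proto = shape.ToProto();  // stored in CompileResult
  xla::Shape param0(proto.parameters(0));          // per-parameter Shape on use
  (void)param0;                                    // only to show the idiom
  return xla::ProgramShape(proto);                 // full shape when needed
}
```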


@ -33,9 +33,9 @@ namespace tfcompile {
struct CompileResult { struct CompileResult {
// Contains object file and meta-info. // Contains object file and meta-info.
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot; std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
xla::ProgramShape program_shape; // Static shape of args and results. xla::ProgramShapeProto program_shape; // Static shape of args and results.
string entry_point; // Name of generated function. string entry_point; // Name of generated function.
int pointer_size = 0; // Size of a pointer in bytes. int pointer_size = 0; // Size of a pointer in bytes.
}; };
// CompileGraph compiles the graph_def into an object file containing a function // CompileGraph compiles the graph_def into an object file containing a function


@ -526,13 +526,15 @@ TEST(TFCompileTest, ProgramShape) {
// muladd has the program shape defined. // muladd has the program shape defined.
MatMulAndAddComp muladd; MatMulAndAddComp muladd;
const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
ASSERT_TRUE(muladd_shape != nullptr); ASSERT_TRUE(muladd_shape != nullptr);
ASSERT_EQ(muladd_shape->parameters_size(), 2); ASSERT_EQ(muladd_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); EXPECT_TRUE(
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2)); ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
EXPECT_TRUE(
ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));
const xla::Shape& muladd_result = muladd_shape->result(); const xla::Shape muladd_result(muladd_shape->result());
ASSERT_EQ(muladd_result.element_type(), xla::TUPLE); ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2); ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
const xla::Shape& muladd_result0 = const xla::Shape& muladd_result0 =


@ -23,7 +23,6 @@ package(
load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@ -38,7 +37,7 @@ cc_library(
":xla_cpu_device", ":xla_cpu_device",
":xla_cpu_jit", ":xla_cpu_jit",
"//tensorflow/compiler/plugin", "//tensorflow/compiler/plugin",
] + if_cuda_is_configured([ ] + if_cuda([
":xla_gpu_device", ":xla_gpu_device",
":xla_gpu_jit", ":xla_gpu_jit",
]), ]),
@ -51,6 +50,7 @@ cc_library(
deps = [ deps = [
":jit_compilation_passes", ":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:cpu_plugin",
@ -76,10 +76,10 @@ cc_library(
srcs = ["xla_cpu_device.cc"], srcs = ["xla_cpu_device.cc"],
visibility = [":friends"], visibility = [":friends"],
deps = [ deps = [
":flags",
":jit_compilation_passes", ":jit_compilation_passes",
":xla_device", ":xla_device",
"//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
@ -210,6 +210,18 @@ cc_library(
# Internal targets below this point. # Internal targets below this point.
cc_library(
name = "flags",
srcs = ["flags.cc"],
hdrs = ["flags.h"],
visibility = [":friends"],
deps = [
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library( cc_library(
name = "common", name = "common",
srcs = [ srcs = [
@ -256,6 +268,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
@ -268,6 +281,7 @@ cc_library(
"//tensorflow/core/kernels:variable_ops", "//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
) )
@ -487,6 +501,7 @@ cc_library(
deps = [ deps = [
":common", ":common",
":encapsulate_util", ":encapsulate_util",
":flags",
":shape_inference_helpers", ":shape_inference_helpers",
":union_find", ":union_find",
":xla_cluster_util", ":xla_cluster_util",
@ -494,8 +509,6 @@ cc_library(
"//tensorflow/cc:ops", "//tensorflow/cc:ops",
"//tensorflow/cc:scope_internal", "//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/legacy_flags:build_xla_ops_pass_flags",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:resource_operation_table",
@ -724,7 +737,10 @@ tf_custom_op_py_library(
visibility = [ visibility = [
":friends", ":friends",
], ],
deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], deps = [
"//tensorflow/compiler/jit/ops:xla_ops_grad",
"//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
],
) )
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.


@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/control_flow_ops.h" #include "tensorflow/cc/ops/control_flow_ops.h"
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/dump_graph.h"
@ -320,10 +320,10 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
return IsXlaCompiledKernel(*n); return IsXlaCompiledKernel(*n);
}); });
-  bool lazy_compilation_enabled = enable_lazy_compilation_
-                                      ? *enable_lazy_compilation_
-                                      : legacy_flags::GetBuildXlaOpsPassFlags()
-                                            .tf_xla_enable_lazy_compilation;
+  bool lazy_compilation_enabled =
+      enable_lazy_compilation_
+          ? *enable_lazy_compilation_
+          : GetBuildXlaOpsPassFlags().tf_xla_enable_lazy_compilation;
for (Node* n : xla_compiled_kernels) { for (Node* n : xla_compiled_kernels) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(


@ -42,14 +42,8 @@ class BuildXlaOpsTest : public ::testing::Test {
.ok()); .ok());
} }
void TearDown() override {
for (Device* device : devices_) {
delete device;
}
}
private: private:
std::vector<Device*> devices_; std::vector<std::unique_ptr<Device>> devices_;
}; };
using ::tensorflow::testing::FindNodeByName; using ::tensorflow::testing::FindNodeByName;


@ -59,8 +59,9 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
SessionOptions options; SessionOptions options;
auto* device_count = options.config.mutable_device_count(); auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1}); device_count->insert({"CPU", 1});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices( TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices_)); options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto; FunctionDefLibrary proto;
for (const auto& fdef : flib) { for (const auto& fdef : flib) {
@ -69,7 +70,7 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
lib_def_ = absl::make_unique<FunctionLibraryDefinition>( lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), proto); OpRegistry::Global(), proto);
OptimizerOptions opts; OptimizerOptions opts;
device_mgr_ = absl::make_unique<DeviceMgr>(devices_); device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>( pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@ -77,7 +78,6 @@ class CreateXlaLaunchOpTest : public ::testing::Test {
} }
FunctionLibraryRuntime* flr_; FunctionLibraryRuntime* flr_;
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;


@ -86,7 +86,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
continue; continue;
} else if (src_xla_computation && !dst_xla_computation) { } else if (src_xla_computation && !dst_xla_computation) {
if (src_outside_compilation) { if (src_outside_compilation) {
// Case 1d: outside compilation to host computation control edge. // Case 1c: outside compilation to host computation control edge.
edges_to_remove.push_back(e); edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>( TF_RETURN_IF_ERROR(AppendToListAttr<string>(
@ -94,7 +94,7 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
} }
} else if (!src_xla_computation && dst_xla_computation) { } else if (!src_xla_computation && dst_xla_computation) {
if (dst_outside_compilation) { if (dst_outside_compilation) {
// Case 1d: host computation control to outside compilation edge. // Case 1c: host computation control to outside compilation edge.
edges_to_remove.push_back(e); edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>( TF_RETURN_IF_ERROR(AppendToListAttr<string>(
@ -103,40 +103,24 @@ Status ProcessControlEdges(Graph* g, const string& xla_computation_attr_name,
} else { // src_xla_computation && dst_xla_computation } else { // src_xla_computation && dst_xla_computation
if (*src_xla_computation != *dst_xla_computation) { if (*src_xla_computation != *dst_xla_computation) {
if (src_outside_compilation && dst_outside_compilation) { if (src_outside_compilation && dst_outside_compilation) {
// Case 1c: outside compilation to outside compilation control edge. // Case 1b: outside compilation to outside compilation control edge.
edges_to_remove.push_back(e); edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>( TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name())); e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
} else if (src_outside_compilation && !dst_outside_compilation) { } else if (src_outside_compilation && !dst_outside_compilation) {
// Case 1b: outside compilation to another XLA computation control // Case 1a: outside compilation to another XLA computation control
// edge. // edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>( TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->src(), kXlaConnectedToOtherXlaComputationAttrName, e->src(), kXlaConnectedToOtherXlaComputationAttrName,
*dst_xla_computation)); *dst_xla_computation));
} else if (!src_outside_compilation && dst_outside_compilation) { } else if (!src_outside_compilation && dst_outside_compilation) {
// Case 1b: another XLA computation to outside compilation control // Case 1a: another XLA computation to outside compilation control
// edge. // edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>( TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaConnectedFromOtherXlaComputationAttrName, e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
*src_xla_computation)); *src_xla_computation));
} }
} else { // *src_xla_computation == *dst_xla_computation
if (src_outside_compilation && dst_outside_compilation) {
if (*src_outside_compilation != *dst_outside_compilation) {
// Case 1c: outside compilation to outside compilation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
}
} else if (src_outside_compilation && !dst_outside_compilation) {
// Case 1a: outside compilation to its XLA computation control edge.
ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
} else if (!src_outside_compilation && dst_outside_compilation) {
// Case 1a: XLA computation to outside compilation in it control edge.
ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
}
} }
} }
} }
@ -181,12 +165,6 @@ Status ProcessXlaToXlaDataEdges(Graph* g,
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()}); edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
VLOG(4) << "XLA -> XLA edge: " << e->DebugString(); VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
} }
} else { // *src_xla_computation == *dst_xla_computation
if (src_outside_compilation && dst_outside_compilation &&
*src_outside_compilation != *dst_outside_compilation) {
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
}
} }
} }
@ -263,7 +241,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
// Remove the edge from host to outside compilation. Add a placeholder as // Remove the edge from host to outside compilation. Add a placeholder as
// outside compilation node input. // outside compilation node input.
std::map<string, Node*> placeholders; std::map<std::pair<string, int>, Node*> placeholders;
for (int i = 0; i < edges.size(); i++) { for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id); Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e; const Edge* e;
@ -275,9 +253,10 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
// Find or create placeholder node. // Find or create placeholder node.
string new_name = string new_name =
edges[i].is_host_to_outside_compilation edges[i].is_host_to_outside_compilation
? absl::StrCat(src->name(), "_host_to_oc_placeholder") ? absl::StrCat(src->name(), "_host_to_oc_placeholder_", src_output)
: absl::StrCat(src->name(), "_oc_to_host_placeholder"); : absl::StrCat(src->name(), "_oc_to_host_placeholder_", src_output);
auto iter = placeholders.find(new_name); auto placeholder_index = std::make_pair(src->name(), src_output);
auto iter = placeholders.find(placeholder_index);
Node* placeholder_node; Node* placeholder_node;
if (iter == placeholders.end()) { if (iter == placeholders.end()) {
NodeDefBuilder placeholder_builder(new_name, "Placeholder"); NodeDefBuilder placeholder_builder(new_name, "Placeholder");
@ -310,7 +289,7 @@ Status ProcessDataEdgeBetweenOutsideCompilationAndHostComputation(
Status s; Status s;
placeholder_node = g->AddNode(placeholder_def, &s); placeholder_node = g->AddNode(placeholder_def, &s);
TF_RETURN_IF_ERROR(s); TF_RETURN_IF_ERROR(s);
placeholders[new_name] = placeholder_node; placeholders[placeholder_index] = placeholder_node;
} else { } else {
placeholder_node = iter->second; placeholder_node = iter->second;
} }
@ -594,14 +573,244 @@ Status AddControlDependencies(
return Status::OK(); return Status::OK();
} }
// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
// `PreprocessEdgesBetweenOutsideCompilations` for details.
Status PreprocessControlEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Gather edges to remove. We should not remove the edge while iterating.
std::vector<const Edge*> edges_to_remove;
for (const Edge* e : g->edges()) {
if (!e->IsControlEdge()) {
continue;
}
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);
if (src_outside_compilation && dst_outside_compilation) {
if (*src_outside_compilation != *dst_outside_compilation) {
// Case 1a: outside compilation to outside compilation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
e->src()->name()));
}
} else if (src_outside_compilation && !dst_outside_compilation) {
// Case 1b: outside compilation to its XLA computation control edge.
ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
} else if (!src_outside_compilation && dst_outside_compilation) {
// Case 1b: XLA computation to outside compilation in it control edge.
ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
}
}
for (auto e : edges_to_remove) {
g->RemoveEdge(e);
}
return Status::OK();
}
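
The loop above first records the edges to drop and only removes them afterwards, so the graph's edge set is never mutated while it is being iterated. A minimal sketch of that gather-then-remove pattern on a toy edge list (the "oc_N/" cluster-naming convention is invented for the example):

```
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical control edges as (src, dst) name pairs.
  std::vector<std::pair<std::string, std::string>> edges = {
      {"oc_0/a", "oc_1/b"}, {"oc_0/a", "oc_0/c"}, {"host/d", "oc_1/b"}};

  // Pass 1: decide which edges to drop (here: edges between different
  // outside compilation clusters, identified by the made-up prefix).
  std::vector<std::pair<std::string, std::string>> to_remove;
  for (const auto& e : edges) {
    std::string src_cluster = e.first.substr(0, e.first.find('/'));
    std::string dst_cluster = e.second.substr(0, e.second.find('/'));
    if (src_cluster != dst_cluster && src_cluster != "host" &&
        dst_cluster != "host") {
      to_remove.push_back(e);  // Record; do not erase mid-iteration.
    }
  }

  // Pass 2: apply the removals.
  for (const auto& e : to_remove) {
    edges.erase(std::remove(edges.begin(), edges.end(), e), edges.end());
  }
  std::cout << edges.size() << " edges remain\n";  // Prints "2 edges remain".
}
```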
// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
// `PreprocessEdgesBetweenOutsideCompilations` for details.
Status PreprocessDataEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Gather data edges between different outside compilations. Notice that we
// do not store `Edge*` directly because we remove some nodes while adding
// Placeholder nodes, and those Edge pointers might be invalidated.
struct EdgeInfo {
int dst_input, dst_node_id;
};
std::vector<EdgeInfo> edges;
for (const Edge* e : g->edges()) {
if (e->IsControlEdge()) {
continue;
}
auto src_outside_compilation =
GetStringAttr(*e->src(), outside_compilation_attr_name);
auto dst_outside_compilation =
GetStringAttr(*e->dst(), outside_compilation_attr_name);
if (src_outside_compilation && dst_outside_compilation &&
*src_outside_compilation != *dst_outside_compilation) {
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
VLOG(4) << "Oc -> oc edge: " << e->DebugString();
}
}
// Remove the edge from host to outside compilation. Add a placeholder as
// outside compilation node input.
std::map<std::pair<string, int>, Node*> placeholders;
for (int i = 0; i < edges.size(); i++) {
Node* dst = g->FindNodeId(edges[i].dst_node_id);
const Edge* e;
TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
Node* src = e->src();
int src_output = e->src_output(), dst_input = e->dst_input();
g->RemoveEdge(e);
// Find or create placeholder node.
string new_name =
absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
auto placeholder_index = std::make_pair(src->name(), src_output);
auto iter = placeholders.find(placeholder_index);
Node* placeholder_node;
if (iter == placeholders.end()) {
NodeDefBuilder placeholder_builder(new_name, "Placeholder");
placeholder_builder.Attr("dtype", src->output_type(src_output));
string outside_compilation_attr;
TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
outside_compilation_attr_name,
&outside_compilation_attr));
placeholder_builder.Attr(outside_compilation_attr_name,
outside_compilation_attr);
placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
src->name());
placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
src_output);
NodeDef placeholder_def;
TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
Status s;
placeholder_node = g->AddNode(placeholder_def, &s);
TF_RETURN_IF_ERROR(s);
placeholders[placeholder_index] = placeholder_node;
} else {
placeholder_node = iter->second;
}
g->AddEdge(placeholder_node, 0, dst, dst_input);
// Replace `e->dst()` because its input node changed.
NodeDef new_def = dst->def();
*new_def.mutable_input(dst_input) = placeholder_node->name();
TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
// Other edges in `edges` might have `e->dst()` as their dst node. Before
// removing `e->dst()`, update those entries to refer to `dst_replace_node`
// instead.
for (int j = i + 1; j < edges.size(); j++) {
if (edges[j].dst_node_id == edges[i].dst_node_id) {
edges[j].dst_node_id = dst_replace_node->id();
}
}
}
return Status::OK();
}
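
The map above is keyed by (source node name, src_output) rather than by the generated placeholder name, so a multi-output source such as an IdentityN gets exactly one placeholder per output, shared by every consumer of that output. A standalone sketch of that keying scheme (FakePlaceholder and the edge list are illustrative only):

```
#include <iostream>
#include <map>
#include <string>
#include <utility>

struct FakePlaceholder {
  std::string name;
};

int main() {
  std::map<std::pair<std::string, int>, FakePlaceholder> placeholders;

  // Hypothetical cross-outside-compilation data edges: (src, src_output).
  std::pair<std::string, int> uses[] = {
      {"identityn_0", 0}, {"identityn_0", 1}, {"identityn_0", 0}};

  for (const auto& use : uses) {
    auto iter = placeholders.find(use);
    if (iter == placeholders.end()) {
      // Name mirrors the "<src>_oc_to_oc_placeholder_<src_output>" scheme.
      FakePlaceholder p{use.first + "_oc_to_oc_placeholder_" +
                        std::to_string(use.second)};
      placeholders.emplace(use, p);
    }
  }
  // Two distinct outputs -> exactly two placeholders, despite three uses.
  std::cout << placeholders.size() << " placeholders created\n";
}
```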
// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
// `PostprocessEdgesBetweenOutsideCompilations` for details.
Status PostprocessDataEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Gather all placeholder nodes created for outside compilation to outside
// compilation data edges.
std::vector<Node*> placeholder_nodes;
for (Node* n : g->nodes()) {
if (n->type_string() == "Placeholder" &&
HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
placeholder_nodes.push_back(n);
}
}
// Remove the placeholder nodes, and reconnect original edge.
auto node_name_index = g->BuildNodeNameIndex();
for (auto n : placeholder_nodes) {
string node_name;
int node_src_output;
TF_RETURN_IF_ERROR(GetNodeAttr(
n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
TF_RETURN_IF_ERROR(GetNodeAttr(
n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
auto iter = node_name_index.find(node_name);
if (iter == node_name_index.end()) {
return errors::Internal(
"Cannot find original node for oc -> host placeholder node ",
node_name);
}
// Change all usage nodes to use the original node instead.
Node* original_node = iter->second;
std::vector<const Edge*> control_edges;
std::vector<OutEdgeInfo> data_edges;
for (auto e : n->out_edges()) {
if (e->IsControlEdge()) {
control_edges.push_back(e);
} else {
data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
}
}
for (const Edge* e : control_edges) {
g->AddControlEdge(original_node, e->dst());
g->RemoveEdge(e);
}
for (int i = 0; i < data_edges.size(); i++) {
Node* dst = data_edges[i].dst;
NodeDef new_def = dst->def();
int dst_input = data_edges[i].dst_input;
*new_def.mutable_input(dst_input) =
absl::StrCat(original_node->name(), ":", node_src_output);
TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
const Edge* edge_to_replace = nullptr;
TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
g->RemoveEdge(edge_to_replace);
g->AddEdge(original_node, node_src_output, replace_node, dst_input);
// Other edges might have `dst` as dst node. Update those edges with
// `replace_node`.
for (int j = i + 1; j < data_edges.size(); j++) {
if (data_edges[j].dst == dst) {
data_edges[j].dst = replace_node;
}
}
// Other placeholder nodes might still look up `dst` by name. Update
// `node_name_index` so the name now maps to `replace_node`.
node_name_index[replace_node->name()] = replace_node;
}
// Remove placeholder node.
g->RemoveNode(n);
}
return Status::OK();
}
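
Each consumer of a placeholder is rewritten so its input string points back at "<original node>:<src_output>" before the placeholder itself is deleted. A tiny sketch of that input-string rewrite on a fake NodeDef-style input list (the names are borrowed from the tests above purely as an illustration):

```
#include <iostream>
#include <string>
#include <vector>

int main() {
  // A fake NodeDef-like input list that currently reads from a placeholder.
  std::vector<std::string> inputs = {"identityn_0_oc_to_oc_placeholder_1",
                                     "const_0"};
  const std::string original_node = "identityn_0";
  const int src_output = 1;
  const int dst_input = 0;

  // Point the consumer back at the original producer output.
  inputs[dst_input] = original_node + ":" + std::to_string(src_output);

  for (const auto& in : inputs) std::cout << in << "\n";
  // Prints:
  //   identityn_0:1
  //   const_0
}
```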
// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
// `PostprocessEdgesBetweenOutsideCompilations` for details.
Status PostprocessControlEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
auto node_name_index = g->BuildNodeNameIndex();
// Reconnect outside compilation to outside compilation control edge.
for (Node* n : g->nodes()) {
std::vector<string> control_deps;
Status s =
GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
&control_deps);
if (!s.ok()) {
if (s.code() != error::NOT_FOUND) {
return s;
} else {
continue;
}
} else {
n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
for (const string& control_input : control_deps) {
auto iter = node_name_index.find(control_input);
if (iter == node_name_index.end()) {
return errors::Internal("Cannot find original node for ",
control_input);
}
g->AddControlEdge(iter->second, n);
}
}
}
return Status::OK();
}
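
GetNodeAttr failing with NOT_FOUND is treated above as "this node recorded no control dependencies", so the node is simply skipped, while any other error is propagated. A small sketch of that status-dispatch shape, using an invented Code enum and GetControlDeps helper rather than TensorFlow's Status:

```
#include <iostream>
#include <string>
#include <vector>

enum class Code { kOk, kNotFound, kInternal };

struct LookupResult {
  Code code;
  std::vector<std::string> control_deps;
};

// Hypothetical attribute lookup; in the pass above this is GetNodeAttr.
LookupResult GetControlDeps(const std::string& node) {
  if (node == "oc_1/recv") return {Code::kOk, {"oc_0/send"}};
  return {Code::kNotFound, {}};
}

int main() {
  for (std::string node : {"oc_1/recv", "host/const"}) {
    LookupResult r = GetControlDeps(node);
    if (r.code == Code::kNotFound) continue;  // No recorded deps: skip node.
    if (r.code != Code::kOk) return 1;        // Any other error: propagate.
    for (const auto& dep : r.control_deps)
      std::cout << dep << " -> " << node << "\n";  // Reconnect control edge.
  }
}
```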
} // namespace } // namespace
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes"; const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
const char kXlaConnectedToXlaComputationAttrName[] =
"_xla_connected_to_xla_computation";
const char kXlaConnectedFromXlaComputationAttrName[] =
"_xla_connected_from_xla_computation";
const char kXlaConnectedToOtherXlaComputationAttrName[] = const char kXlaConnectedToOtherXlaComputationAttrName[] =
"_xla_connected_to_other_xla_computation"; "_xla_connected_to_other_xla_computation";
const char kXlaConnectedFromOtherXlaComputationAttrName[] = const char kXlaConnectedFromOtherXlaComputationAttrName[] =
@ -616,6 +825,15 @@ const char kHostToOutsideCompilationOriginalNodeAttrName[] =
"_xla_host_to_oc_node_name"; "_xla_host_to_oc_node_name";
const char kHostToOutsideCompilationSrcOutputAttrName[] = const char kHostToOutsideCompilationSrcOutputAttrName[] =
"_xla_host_to_oc_src_output"; "_xla_host_to_oc_src_output";
const char kXlaConnectedToXlaComputationAttrName[] =
"_xla_connected_to_xla_computation";
const char kXlaConnectedFromXlaComputationAttrName[] =
"_xla_connected_from_xla_computation";
const char kOutsideCompilationOriginalNodeAttrName[] =
"_xla_oc_to_oc_node_name";
const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
"_xla_control_dependencies_within_xla_cluster";
Status PerformStaticShapeInferenceBeforeEncapsulation( Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name, Graph* g, const string& xla_computation_attr_name,
@ -699,4 +917,39 @@ Status PostprocessForEncapsulation(
return Status::OK(); return Status::OK();
} }
Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
// Remove edges from source node to outside compilation nodes, and edges
// from outside compilation nodes to sink node.
std::vector<const Edge*> edges_to_remove;
for (const Edge* e : g->source_node()->out_edges()) {
if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
edges_to_remove.push_back(e);
}
}
for (const Edge* e : g->sink_node()->in_edges()) {
if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
edges_to_remove.push_back(e);
}
}
for (auto e : edges_to_remove) {
g->RemoveEdge(e);
}
TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
g, outside_compilation_attr_name));
TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
g, outside_compilation_attr_name));
return Status::OK();
}
Status PostprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name) {
TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
g, outside_compilation_attr_name));
TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
g, outside_compilation_attr_name));
return Status::OK();
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -44,14 +44,6 @@ Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name, Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name); const string& outside_compilation_attr_name);
// Attribute indicating that some ops in this node's XLA computation has control
// dependency on this node. Attribute value will always be "true".
extern const char kXlaConnectedToXlaComputationAttrName[];
// Attribute indicating that this node has control dependency on some ops in
// this node's XLA computation. Attribute value will always be "true".
extern const char kXlaConnectedFromXlaComputationAttrName[];
// Attribute indicating that some ops in other XLA computation has control // Attribute indicating that some ops in other XLA computation has control
// dependency on this node. Attribute value will be a list of string (XLA // dependency on this node. Attribute value will be a list of string (XLA
// computation names). // computation names).
@ -81,6 +73,14 @@ extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
// int (src_output for original edge). // int (src_output for original edge).
extern const char kOutsideCompilationToHostSrcOutputAttrName[]; extern const char kOutsideCompilationToHostSrcOutputAttrName[];
// Attribute indicating that some ops in this node's XLA computation have a
// control dependency on this node. Attribute value will always be "true".
extern const char kXlaConnectedToXlaComputationAttrName[];
// Attribute indicating that this node has a control dependency on some ops in
// this node's XLA computation. Attribute value will always be "true".
extern const char kXlaConnectedFromXlaComputationAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a // Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for a host node. Attribute value will be string // temporary input node for a host node. Attribute value will be string
// (original input node name). // (original input node name).
@ -91,19 +91,31 @@ extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
// for original edge). // for original edge).
extern const char kHostToOutsideCompilationSrcOutputAttrName[]; extern const char kHostToOutsideCompilationSrcOutputAttrName[];
// Preprocesses the graph for encapsulation. It will perform the following // Attribute indicating that this is a Placeholder node added to act as a
// operations in order: // temporary input node for an outside compilation node. Attribute value will be
// string (original input node name).
extern const char kOutsideCompilationOriginalNodeAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// int (src_output for original edge).
extern const char kOutsideCompilationSrcOutputAttrName[];
// Attribute indicating that this node has control dependencies on some other
// nodes within the same XLA cluster. Attribute value will be a list of string
// (node names).
extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
// Preprocesses edges between different XLA clusters for encapsulation. It will
// perform the following operations in order:
// //
// 1a. For control edges between outside compilation and its XLA computation, // 1a. For control edges between outside compilation and another XLA
// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
// outside compilation node.
// 1b. For control edges between outside compilation and another XLA
// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName // computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
// = XLA computation node name" to the outside compilation node. // = XLA computation node name" to the outside compilation node.
// 1c. For control edges between different outside compilations, remove the edge // 1b. For control edges between different outside compilations (in different
// and add attr "kXlaControlDependenciesAttrName = src node name" to dst // XLA computations), remove the edge and add attr
// node. // "kXlaControlDependenciesAttrName = src node name" to dst node.
// 1d. For control edges between outside compilation and host computation, // 1c. For control edges between outside compilation and host computation,
// remove the edge and add attr "kXlaControlDependenciesAttrName = src node // remove the edge and add attr "kXlaControlDependenciesAttrName = src node
// name" to dst node. // name" to dst node.
// 2. For data edges between different XLA computations, if either src or dst // 2. For data edges between different XLA computations, if either src or dst
@ -146,26 +158,53 @@ struct XlaClusterInfo {
const std::map<string, int> host_compute_core; const std::map<string, int> host_compute_core;
}; };
// Postprocesses the graph for encapsulation. This function reverts what // Postprocesses edges between different XLA clusters for encapsulation. This
// `PreprocessForEncapsulation` did. It will perform the following operations in // function reverts what `PreprocessForEncapsulation` did. It will perform the
// order: // following operations in order:
// //
// 1. Remove Placeholder nodes between outside compilation and host computation // 1. Remove Placeholder nodes between outside compilation and host computation
// (created in `PreprocessForEncapsulation` step 3). // (created in `PreprocessForEncapsulation` step 3).
// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2. // 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
// 3a. Reconnect control edges between different outside compilations (marked by // 3a. Reconnect control edges between outside compilation and another XLA
// `PreprocessForEncapsulation` step 1c) and control edges between outside // computation (marked by `PreprocessForEncapsulation` step 1a).
// compilation and host computation (marked by `PreprocessForEncapsulation` // 3b. Reconnect control edges between different outside compilations (marked by
// step 1d). // `PreprocessForEncapsulation` step 1b).
// 3b. Reconnect control edges between outside compilation and another XLA // 3c. Reconnect control edges between outside compilation and host computation
// computation (marked by `PreprocessForEncapsulation` step 1b). // (marked by `PreprocessForEncapsulation` step 1c).
// Notice that control edges marked by `PreprocessForEncapsulation` step 1a are
// not handled here. They are handled in `RewriteOutsideCompilationSubgraphFn`.
Status PostprocessForEncapsulation( Status PostprocessForEncapsulation(
Graph* g, const string& xla_computation_attr_name, Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name, const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters); const std::unordered_map<string, XlaClusterInfo>& clusters);
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//
// 0. Remove edges from source node to outside compilation nodes, and edges
// from outside compilation nodes to sink node.
// 1a. For edges between different outside compilation clusters, remove the edge
// and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node
// name" to dst node.
// 1b. For control edges between outside compilation and its XLA computation,
// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
// outside compilation node.
// 2. For data edges between different outside compilations, remove the edge
// and create a Placeholder node as dst node's input.
Status PreprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name);
// Postprocesses edges within the same XLA cluster. This function reverts what
// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the
// following operations in order:
//
// 1. Remove Placeholder nodes between different outside compilations (created
// in `PreprocessEdgesBetweenOutsideCompilations` step 2).
// 2a. Reconnect control edges between different outside compilations (marked by
// `PreprocessEdgesBetweenOutsideCompilations` step 1a).
// Notice that control edges marked by
// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here.
// They are handled in `RewriteOutsideCompilationSubgraphFn`.
Status PostprocessEdgesBetweenOutsideCompilations(
Graph* g, const string& outside_compilation_attr_name);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ #endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_

View File

@ -107,28 +107,19 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
identity4_node->AddAttr("_xla", "1"); identity4_node->AddAttr("_xla", "1");
identity4_node->AddAttr("_oc", "0"); identity4_node->AddAttr("_oc", "0");
identity5_node->AddAttr("_xla", "1"); identity5_node->AddAttr("_xla", "1");
// Case 1a: control edges between outside compilation and its XLA computation. // Case 1a: control edges between outside compilation and another XLA
g.AddControlEdge(add_node, identity0_node);
g.AddControlEdge(identity0_node, identity1_node);
// Case 1b: control edges between outside compilation and another XLA
// computation. // computation.
g.AddControlEdge(identity0_node, identity3_node); g.AddControlEdge(identity0_node, identity3_node);
g.AddControlEdge(identity1_node, identity4_node); g.AddControlEdge(identity1_node, identity4_node);
// Case 1c: control edges between different outside compilations. // Case 1b: control edges between different outside compilations.
g.AddControlEdge(identity0_node, identity4_node); g.AddControlEdge(identity0_node, identity4_node);
// Case 1d: control edges between outside compilation and host computation. // Case 1c: control edges between outside compilation and host computation.
g.AddControlEdge(const0_node, identity0_node); g.AddControlEdge(const0_node, identity0_node);
g.AddControlEdge(identity0_node, identity2_node); g.AddControlEdge(identity0_node, identity2_node);
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
// Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
// outside compilation node.
EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
kXlaConnectedFromXlaComputationAttrName));
EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
kXlaConnectedToXlaComputationAttrName));
// Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name"
// to the outside compilation node. // to the outside compilation node.
std::vector<string> attr; std::vector<string> attr;
TF_CHECK_OK(GetNodeAttr(identity0_node->def(), TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
@ -140,13 +131,13 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
kXlaConnectedFromOtherXlaComputationAttrName, &attr)); kXlaConnectedFromOtherXlaComputationAttrName, &attr));
EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "0"); EXPECT_EQ(attr[0], "0");
// Case 1c: add attr "_xla_control_deps = src node name" to dst node. // Case 1b: add attr "_xla_control_deps = src node name" to dst node.
attr.clear(); attr.clear();
TF_CHECK_OK(GetNodeAttr(identity4_node->def(), TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
kXlaControlDependenciesAttrName, &attr)); kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1); EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "identity0"); EXPECT_EQ(attr[0], "identity0");
// Case 1d: add attr "_xla_control_deps = src node name" to dst node. // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
attr.clear(); attr.clear();
TF_CHECK_OK(GetNodeAttr(identity0_node->def(), TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
kXlaControlDependenciesAttrName, &attr)); kXlaControlDependenciesAttrName, &attr));
@ -162,23 +153,33 @@ TEST(PreprocessForEncapsulationTest, ControlEdges) {
TEST(PreprocessForEncapsulationTest, DataEdges) { TEST(PreprocessForEncapsulationTest, DataEdges) {
// Build the graph: // Build the graph:
// "const_0" and "const_1" in host computation // "const_0" and "const_1" in host computation
// "identityn0" = ("const_0", "const_1") in host computation 0
// "add0" = "const_0" + "const_1" in XLA computation 0 // "add0" = "const_0" + "const_1" in XLA computation 0
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0 // "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
// "identity0" = "add1" in XLA computation 0 // "identity0" = "add1" in XLA computation 0
// "add2" = "add1" + "identity0" in host computation // "add2" = "add1" + "identity0" in host computation
// "add3" = "add1" + "add2" in XLA computation 1 // "add3" = "add1" + "add2" in XLA computation 1
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 1 // "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
// "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
// outside compilation 0
// "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
// outside compilation 0
// "identity1" = "add4" in XLA computation 1 // "identity1" = "add4" in XLA computation 1
// "identity2" = "identity1" in host computation // "identity2" = "identity1" in host computation
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {}); Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {}); Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
auto identityn0 =
ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1); Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0); Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1); Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0); Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2); Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2); Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
{identityn0[0], identityn0[1]});
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4); Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4); Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
Graph g(OpRegistry::Global()); Graph g(OpRegistry::Global());
@ -189,6 +190,8 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"], Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
*identity0_node = node_index["identity0"], *identity0_node = node_index["identity0"],
*add3_node = node_index["add3"], *add4_node = node_index["add4"], *add3_node = node_index["add3"], *add4_node = node_index["add4"],
*add5_node = node_index["add5"],
*identityn1_node = node_index["identityn_1"],
*identity1_node = node_index["identity1"]; *identity1_node = node_index["identity1"];
add0_node->AddAttr("_xla", "0"); add0_node->AddAttr("_xla", "0");
add1_node->AddAttr("_xla", "0"); add1_node->AddAttr("_xla", "0");
@ -197,6 +200,10 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
add3_node->AddAttr("_xla", "1"); add3_node->AddAttr("_xla", "1");
add4_node->AddAttr("_xla", "1"); add4_node->AddAttr("_xla", "1");
add4_node->AddAttr("_oc", "0"); add4_node->AddAttr("_oc", "0");
add5_node->AddAttr("_xla", "1");
add5_node->AddAttr("_oc", "0");
identityn1_node->AddAttr("_xla", "1");
identityn1_node->AddAttr("_oc", "0");
identity1_node->AddAttr("_xla", "1"); identity1_node->AddAttr("_xla", "1");
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc")); TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
@ -214,8 +221,9 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
EXPECT_NE(bridge_identity0_add4, nullptr); EXPECT_NE(bridge_identity0_add4, nullptr);
// Step 3: add placeholder for edges between host computation and outside // Step 3: add placeholder for edges between host computation and outside
// compilation. // compilation.
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder"); EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
Node *add1_oc_to_host_placeholder = node_index["add1_oc_to_host_placeholder"]; Node *add1_oc_to_host_placeholder =
node_index["add1_oc_to_host_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(), TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
kOutsideCompilationToHostOriginalNodeAttrName, &str)); kOutsideCompilationToHostOriginalNodeAttrName, &str));
EXPECT_EQ(str, "add1"); EXPECT_EQ(str, "add1");
@ -226,15 +234,34 @@ TEST(PreprocessForEncapsulationTest, DataEdges) {
add4_node = node_index["add4"]; add4_node = node_index["add4"];
ASSERT_NE(add4_node, nullptr); ASSERT_NE(add4_node, nullptr);
EXPECT_EQ(add4_node->def().input(0), EXPECT_EQ(add4_node->def().input(0),
"bridge_identity0_add4_host_to_oc_placeholder"); "bridge_identity0_add4_host_to_oc_placeholder_0");
Node *identity0_host_to_oc_placeholder = Node *identity0_host_to_oc_placeholder =
node_index["bridge_identity0_add4_host_to_oc_placeholder"]; node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationOriginalNodeAttrName, &str)); kHostToOutsideCompilationOriginalNodeAttrName, &str));
EXPECT_EQ(str, "bridge_identity0_add4"); EXPECT_EQ(str, "bridge_identity0_add4");
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(), TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationSrcOutputAttrName, &i)); kHostToOutsideCompilationSrcOutputAttrName, &i));
EXPECT_EQ(i, 0); EXPECT_EQ(i, 0);
// Check different placeholder nodes are created for different src_output.
Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
*placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
EXPECT_NE(placeholder0, nullptr);
EXPECT_NE(placeholder1, nullptr);
// Check we only have 2 placeholder nodes created for "identityn_0".
int placeholder_count = 0;
for (Node *n : g.nodes()) {
if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
string attr;
TF_CHECK_OK(GetNodeAttr(
n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
if (attr == "identityn_0") {
++placeholder_count;
}
}
}
EXPECT_EQ(placeholder_count, 2);
} }
TEST(PostprocessForEncapsulationTest, ControlEdges) { TEST(PostprocessForEncapsulationTest, ControlEdges) {

View File

@ -195,8 +195,11 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr && e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
e->dst()->type_string() != kXlaClusterOutput) { e->dst()->type_string() != kXlaClusterOutput) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Undeclared output of XLA computation. A common cause of this error " "Undeclared output of XLA computation. Some common causes of this "
"is variable initializers that depend on the XLA computation. Edge: ", "error are: 1) variable initializers that depend on the XLA "
"computation; 2) gradient computations that depend on the XLA "
"computation, which can be mitigated by moving gradient computations "
"inside XLA computation. Offending edge: ",
e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":", e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
e->dst_input()); e->dst_input());
} }

View File

@ -366,7 +366,7 @@ Status ReplaceOrRemoveOutsideCompilationCallNode(
// replace this node with compilation result node. // replace this node with compilation result node.
// 3) all outside compilation graphs. // 3) all outside compilation graphs.
Status ConstructHostGraph( Status ConstructHostGraph(
const string& xla_cluster_name, const string& xla_cluster_name, const string& outside_compilation_attr_name,
const std::vector<string>& outside_compilation_host_graphs, const std::vector<string>& outside_compilation_host_graphs,
FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) { FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
host_graph->reset(new Graph(fld)); host_graph->reset(new Graph(fld));
@ -476,6 +476,10 @@ Status ConstructHostGraph(
host_graph->get(), host_graph->get(),
std::unordered_set<const Node*>{(*host_graph)->sink_node()}); std::unordered_set<const Node*>{(*host_graph)->sink_node()});
// Postprocess edges between different outside compilations.
TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
host_graph->get(), outside_compilation_attr_name));
if (VLOG_IS_ON(4)) { if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile( dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_host_graph_for_", absl::StrCat("extract_outside_compilation_host_graph_for_",
@ -801,6 +805,11 @@ Status ExtractOutsideCompilationForFunction(
}, },
&fbody)); &fbody));
std::unique_ptr<FunctionBody> fbody_deleter(fbody); std::unique_ptr<FunctionBody> fbody_deleter(fbody);
// Preprocess edges between different outside compilations. They will be
// restored in `ConstructHostGraph()`.
TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
fbody->graph, outside_compilation_attr_name));
if (VLOG_IS_ON(4)) { if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile( dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_for_func_before_", func_name), absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
@ -860,8 +869,9 @@ Status ExtractOutsideCompilationForFunction(
// Construct host graph. // Construct host graph.
if (!outside_compilation_host_graphs.empty()) { if (!outside_compilation_host_graphs.empty()) {
TF_RETURN_IF_ERROR(ConstructHostGraph( TF_RETURN_IF_ERROR(
xla_cluster_name, outside_compilation_host_graphs, fld, host_graph)); ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
outside_compilation_host_graphs, fld, host_graph));
} }
// Remove the outside compilation graphs from function library. // Remove the outside compilation graphs from function library.

View File

@ -290,21 +290,18 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes)); TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes));
EXPECT_EQ(shapes.size(), 1); EXPECT_EQ(shapes.size(), 1);
EXPECT_EQ(shapes[0].dim_size(), 1); EXPECT_EQ(shapes[0].dim_size(), 1);
// Check XlaHostCompute nodes' "shape_inference_graph" attr. "0" should have a // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
// non-empty value, and "1" should have an empty value. // empty values.
string shape_inference_graph; string shape_inference_graph;
TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph", TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
&shape_inference_graph)); &shape_inference_graph));
EXPECT_EQ(shape_inference_graph, EXPECT_EQ(shape_inference_graph, "");
"_outside_compilation_shape_inference_cluster_0");
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph", TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
&shape_inference_graph)); &shape_inference_graph));
EXPECT_EQ(shape_inference_graph, ""); EXPECT_EQ(shape_inference_graph, "");
// Check `shape_inference_graphs`. // Check `shape_inference_graphs`.
EXPECT_EQ(shape_inference_graphs.size(), 1); EXPECT_EQ(shape_inference_graphs.size(), 0);
EXPECT_EQ(shape_inference_graphs[0],
"_outside_compilation_shape_inference_cluster_0");
// Check `host_graph`: verify we have key placeholder and sequencer. // Check `host_graph`: verify we have key placeholder and sequencer.
Node *key_placeholder = nullptr, *sequencer = nullptr; Node *key_placeholder = nullptr, *sequencer = nullptr;
@ -333,8 +330,8 @@ TEST(ExtractOutsideCompilationForFunctionTest, Basic) {
send_recv_nodes.push_back(n); send_recv_nodes.push_back(n);
} }
} }
EXPECT_EQ(num_send_from_host, 2); EXPECT_EQ(num_send_from_host, 1);
EXPECT_EQ(num_recv_at_host, 2); EXPECT_EQ(num_recv_at_host, 1);
for (Node *n : send_recv_nodes) { for (Node *n : send_recv_nodes) {
Node *input_node; Node *input_node;
TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node)); TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node));

View File

@ -0,0 +1,152 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <mutex> // NOLINT
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
BuildXlaOpsPassFlags* build_ops_flags;
DumpGraphFlags* dump_graph_flags;
MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
void AppendDumpGraphFlagsInternal(std::vector<Flag>* flag_list) {
std::vector<Flag> new_flags = {
Flag("tf_dump_graph_prefix", &dump_graph_flags->tf_dump_graph_prefix,
"Path prefix to which graphs dumped during debugging should be "
"written."),
};
flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
}
void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
std::vector<Flag> new_flags = {
Flag("tf_xla_auto_jit", &mark_for_compilation_flags->tf_xla_auto_jit,
"Control compilation of operators into XLA computations on CPU and "
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
"things very likely to be improved; 2 = on for everything. "
"Experimental."),
Flag("tf_xla_min_cluster_size",
&mark_for_compilation_flags->tf_xla_min_cluster_size,
"Minimum number of operators in an XLA compilation. Ignored for "
"operators placed on an XLA device or operators explicitly marked "
"for compilation."),
Flag("tf_xla_max_cluster_size",
&mark_for_compilation_flags->tf_xla_max_cluster_size,
"Maximum number of operators in an XLA compilation."),
Flag("tf_xla_clustering_debug",
&mark_for_compilation_flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
Flag("tf_xla_cpu_global_jit",
&mark_for_compilation_flags->tf_xla_cpu_global_jit,
"Enables global JIT compilation for CPU via SessionOptions."),
Flag("tf_xla_clustering_fuel",
&mark_for_compilation_flags->tf_xla_clustering_fuel,
"Places an artificial limit on the number of ops marked as "
"eligible for clustering."),
Flag("tf_xla_fusion_only",
&mark_for_compilation_flags->tf_xla_fusion_only,
"enable fusion of element-wise operations only using XLA when "
"global_jit_level is ON*.")};
flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
}
void AllocateAndParseFlags() {
build_ops_flags = new BuildXlaOpsPassFlags;
build_ops_flags->tf_xla_enable_lazy_compilation = true;
dump_graph_flags = new DumpGraphFlags;
dump_graph_flags->tf_dump_graph_prefix = "/tmp/";
mark_for_compilation_flags = new MarkForCompilationPassFlags;
mark_for_compilation_flags->tf_xla_auto_jit = 0;
mark_for_compilation_flags->tf_xla_min_cluster_size = 2;
mark_for_compilation_flags->tf_xla_max_cluster_size =
std::numeric_limits<int32>::max();
mark_for_compilation_flags->tf_xla_clustering_debug = false;
mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
mark_for_compilation_flags->tf_xla_clustering_fuel =
std::numeric_limits<int64>::max();
mark_for_compilation_flags->tf_xla_fusion_only = false;
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_enable_lazy_compilation",
&build_ops_flags->tf_xla_enable_lazy_compilation, ""),
Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
Flag("tf_xla_always_defer_compilation",
&ops_flags->tf_xla_always_defer_compilation, ""),
});
AppendDumpGraphFlagsInternal(flag_list);
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
}
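
The getters below all funnel through the same std::call_once, so the first access to any flag group parses TF_XLA_FLAGS exactly once and later calls return the already-built structs. A self-contained sketch of that lazy, once-only initialization (DemoFlags and DEMO_XLA_FLAGS are invented for the example; the real code parses TF_XLA_FLAGS with xla::ParseFlagsFromEnvAndDieIfUnknown):

```
#include <cstdlib>
#include <iostream>
#include <mutex>
#include <string>

struct DemoFlags {
  bool enable_lazy_compilation = true;
};

namespace {
DemoFlags* demo_flags = nullptr;
std::once_flag demo_flags_init;

void AllocateAndParseDemoFlags() {
  demo_flags = new DemoFlags;
  // Hypothetical env var; stands in for the TF_XLA_FLAGS parsing above.
  if (const char* v = std::getenv("DEMO_XLA_FLAGS")) {
    if (std::string(v) == "--no_lazy_compilation") {
      demo_flags->enable_lazy_compilation = false;
    }
  }
}
}  // namespace

const DemoFlags& GetDemoFlags() {
  // First caller parses; everyone else reuses the same struct.
  std::call_once(demo_flags_init, &AllocateAndParseDemoFlags);
  return *demo_flags;
}

int main() {
  std::cout << "lazy compilation: "
            << (GetDemoFlags().enable_lazy_compilation ? "on" : "off") << "\n";
}
```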
} // namespace
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return *build_ops_flags;
}
DumpGraphFlags* GetDumpGraphFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return dump_graph_flags;
}
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return mark_for_compilation_flags;
}
XlaDeviceFlags* GetXlaDeviceFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return device_flags;
}
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return *ops_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
void AppendDumpGraphFlags(std::vector<Flag>* flag_list) {
std::call_once(flags_init, &AllocateAndParseFlags);
AppendDumpGraphFlagsInternal(flag_list);
}
} // namespace tensorflow

View File

@ -13,10 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ #ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ #define TENSORFLOW_COMPILER_JIT_FLAGS_H_
// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
#include <vector> #include <vector>
@ -24,15 +22,8 @@ limitations under the License.
#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow { namespace tensorflow {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with the XLA bridge's // Flags associated with the XLA bridge's mark_for_compilation_pass module.
// mark_for_compilation_pass module.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with the XLA bridge's
// mark_for_compilation_pass module.
struct MarkForCompilationPassFlags { struct MarkForCompilationPassFlags {
int32 tf_xla_auto_jit; // Control compilation of operators into XLA int32 tf_xla_auto_jit; // Control compilation of operators into XLA
// computations on CPU and GPU devices. 0 = use // computations on CPU and GPU devices. 0 = use
@ -57,12 +48,56 @@ struct MarkForCompilationPassFlags {
// only using XLA. // only using XLA.
}; };
// Return a pointer to the MarkForCompilationPassFlags struct; // Flags associated with the XLA bridge's xla_device module.
struct XlaDeviceFlags {
// Switch the CPU device into "on-demand" mode, where, instead of
// autoclustering, ops are compiled one by one just-in-time.
// Enabling this mode by a legacy flag is a temporary mechanism. When this
// feature is battle-tested, we will switch this to be a session option.
bool tf_xla_compile_on_demand;
};
// Flags common to the _Xla* ops and their kernels.
struct XlaOpsCommonFlags {
// If true, _XlaCompile always refuses to compile the cluster, which means the
// XLA clusters always run in the TF executor. Defaults to false.
bool tf_xla_always_defer_compilation;
};
// Flags for the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
// Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
// Defaults to true.
bool tf_xla_enable_lazy_compilation;
};
// Flags for the XLA bridge's dump_graph module.
struct DumpGraphFlags {
// Path prefix to which graphs dumped during debugging should be written.
string tf_dump_graph_prefix;
};
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer. // repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned. // This should be called only after Flags::Parse() has returned.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
} // namespace legacy_flags // Getters for flags structs defined above. The first call to any of these
// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer
// always return the same pointer.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
XlaDeviceFlags* GetXlaDeviceFlags();
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
DumpGraphFlags* GetDumpGraphFlags();
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//
// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet.
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
void AppendDumpGraphFlags(std::vector<tensorflow::Flag>* flag_list);
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_ #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/dump_graph.h"
@ -208,8 +208,12 @@ Status ComputeSliceSize(const Scope& host_scope,
DCHECK_EQ(slice_size.back().type(), DT_INT64); DCHECK_EQ(slice_size.back().type(), DT_INT64);
} }
*size = ops::Concat(host_scope.WithOpName("slice_size"), slice_size, // Trivial ConcatV2 nodes (with exactly one input) are disallowed.
ops::Const(host_scope.WithOpName("concat_axis"), 0)); *size =
slice_size.size() == 1
? slice_size[0]
: ops::Concat(host_scope.WithOpName("slice_size"), slice_size,
ops::Const(host_scope.WithOpName("concat_axis"), 0));
return Status::OK(); return Status::OK();
} }
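
When the slice size has a single component, the code above forwards it directly because a one-input ConcatV2 node is rejected. A toy sketch of that "skip the concat when there is only one piece" choice (Concat here is just a string-building stand-in, not the TensorFlow op):

```
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for building a ConcatV2-like result.
std::string Concat(const std::vector<std::string>& parts) {
  std::string out = "concat(";
  for (size_t i = 0; i < parts.size(); ++i) {
    if (i) out += ", ";
    out += parts[i];
  }
  return out + ")";
}

int main() {
  std::vector<std::string> slice_size = {"dim0_size"};
  // With exactly one component, forward it directly instead of emitting a
  // trivial (and, in the pass above, disallowed) one-input concat.
  std::string size =
      slice_size.size() == 1 ? slice_size[0] : Concat(slice_size);
  std::cout << size << "\n";  // Prints "dim0_size".
}
```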
@ -242,6 +246,9 @@ Status ConvertTensorFlowSliceToStaticShapedSlice(
.WithOpName("static_shaped_slice"), .WithOpName("static_shaped_slice"),
slice_inputs_int64.input, slice_inputs_int64.begin, slice_size) slice_inputs_int64.input, slice_inputs_int64.begin, slice_size)
.node(); .node();
TF_RETURN_IF_ERROR(main_scope.status());
std::vector<string> compile_time_const_inputs; std::vector<string> compile_time_const_inputs;
compile_time_const_inputs.push_back("size"); compile_time_const_inputs.push_back("size");
(*result)->AddAttr(kXlaCompileTimeConstantInputsAttr, (*result)->AddAttr(kXlaCompileTimeConstantInputsAttr,
@ -284,49 +291,45 @@ Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs,
return Status::OK(); return Status::OK();
} }
// If `n` is a slice we can rewrite to have a static shape (i.e. have the output // Return true if `n` is a slice we can rewrite to have a static shape
// shape only depend on the "size" input) then returns the a SliceInputs // (i.e. have the output shape only depend on the "size" input).
// representing the inputs to `n`. Otherwise returns nullopt. xla::StatusOr<bool> IsRewritableSlice(Node* n) {
StatusOrOptional<SliceInputs> IsRewritableSlice(Node* n) {
if (n->type_string() != "Slice") { if (n->type_string() != "Slice") {
return {absl::nullopt}; return false;
} }
if (!GetXlaClusterForNode(*n).has_value()) { if (!GetXlaClusterForNode(*n).has_value()) {
// There is no need to change slice ops outside XLA clusters. // There is no need to change slice ops outside XLA clusters.
return {absl::nullopt}; return false;
} }
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs, TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
GetSliceInputs(n)); GetSliceInputs(n));
if (!slice_inputs.has_value()) { if (!slice_inputs.has_value()) {
return {absl::nullopt}; return false;
} }
// If slice_size[i] < -1 for any i then executing the slice will throw an // If slice_size[i] < -1 for any i then executing the slice will throw an
// error, and we don't do anything here. // error, and we don't do anything here.
bool slice_is_ok = absl::c_all_of(slice_inputs->size_as_vector, return absl::c_all_of(slice_inputs->size_as_vector,
[](int64 size_i) { return size_i >= -1; }); [](int64 size_i) { return size_i >= -1; });
if (!slice_is_ok) {
return {absl::nullopt};
}
return slice_inputs;
} }
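
IsRewritableSlice now collapses to a chain of early-out checks that ends in an all-of test over the size vector. A standalone sketch of the same predicate shape, with a made-up FakeSlice struct standing in for a graph node:

```
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct FakeSlice {
  std::string op;
  std::optional<std::string> xla_cluster;
  std::vector<int64_t> size;  // -1 means "to the end of the dimension".
};

bool IsRewritable(const FakeSlice& n) {
  if (n.op != "Slice") return false;
  if (!n.xla_cluster.has_value()) return false;  // Outside XLA: leave alone.
  // Any size < -1 would make the op fail at runtime; don't touch it.
  return std::all_of(n.size.begin(), n.size.end(),
                     [](int64_t s) { return s >= -1; });
}

int main() {
  FakeSlice ok{"Slice", std::optional<std::string>("cluster_0"), {-1, 500}};
  FakeSlice bad{"Slice", std::optional<std::string>("cluster_0"), {-2}};
  std::cout << IsRewritable(ok) << " " << IsRewritable(bad) << "\n";  // 1 0
}
```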
Status FindAndRewriteSlices(Graph* g, bool* changed) { Status FindAndRewriteSlices(Graph* g, bool* changed) {
std::vector<std::pair<Node*, SliceInputs>> slices_to_rewrite; std::vector<Node*> slices_to_rewrite;
for (Node* n : g->nodes()) { for (Node* n : g->nodes()) {
TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs, TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n));
IsRewritableSlice(n)); if (is_rewritable) {
if (slice_inputs.has_value()) { slices_to_rewrite.push_back(n);
slices_to_rewrite.push_back({n, std::move(*slice_inputs)});
} }
} }
for (const auto& pair : slices_to_rewrite) { for (Node* n : slices_to_rewrite) {
TF_RETURN_IF_ERROR(RewriteSlice(g, pair.first, pair.second, TF_ASSIGN_OR_RETURN(absl::optional<SliceInputs> slice_inputs,
*GetXlaClusterForNode(*pair.first))); GetSliceInputs(n));
TF_RET_CHECK(slice_inputs.has_value());
TF_RETURN_IF_ERROR(
RewriteSlice(g, n, *slice_inputs, *GetXlaClusterForNode(*n)));
} }
if (!slices_to_rewrite.empty()) { if (!slices_to_rewrite.empty()) {
@ -342,8 +345,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) {
Status IncreaseDynamismForAutoJitPass::Run( Status IncreaseDynamismForAutoJitPass::Run(
const GraphOptimizationPassOptions& options) { const GraphOptimizationPassOptions& options) {
legacy_flags::MarkForCompilationPassFlags* flags = MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
legacy_flags::GetMarkForCompilationPassFlags();
if (flags->tf_xla_clustering_debug) { if (flags->tf_xla_clustering_debug) {
dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass", dump_graph::DumpGraphToFile("before_increase_dynamism_for_auto_jit_pass",
**options.graph, options.flib_def); **options.graph, options.flib_def);

View File

@ -27,6 +27,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
using ::testing::_;
using testing::matchers::AssignedDevice; using testing::matchers::AssignedDevice;
using testing::matchers::Attr; using testing::matchers::Attr;
using testing::matchers::Const; using testing::matchers::Const;
@ -142,6 +143,26 @@ TEST(SliceToDynamicSliceRewriteTest, Basic) {
EXPECT_THAT(static_shaped_slice, m_dynamic_slice); EXPECT_THAT(static_shaped_slice, m_dynamic_slice);
} }
TEST(SliceToDynamicSliceRewriteTest, SliceFromVector) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
Output size = ops::Const(root.WithOpName("size"), {-1});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
Node* static_shaped_slice = testing::FindNodeByName(
result.get(), "slice/static_shaped_slice/static_shaped_slice");
EXPECT_NE(static_shaped_slice, nullptr);
EXPECT_THAT(result->nodes(), Not(Contains(NodeWith(Op("ConcatV2")))));
}
TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) { TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
Scope root = Scope::NewRootScope() Scope root = Scope::NewRootScope()
.ExitOnError() .ExitOnError()
@ -166,18 +187,18 @@ TEST(SliceToDynamicSliceRewriteTest, ControlDependencePreserved) {
CtrlDeps(NodeWith(Op("Placeholder"), Name("control"))))); CtrlDeps(NodeWith(Op("Placeholder"), Name("control")))));
} }
int64 ToInt64(int v) { return static_cast<int64>(v); }
TEST(SliceToDynamicSliceRewriteTest, Int64Indices) { TEST(SliceToDynamicSliceRewriteTest, Int64Indices) {
Scope root = Scope::NewRootScope() Scope root = Scope::NewRootScope()
.ExitOnError() .ExitOnError()
.WithAssignedDevice(kDeviceName) .WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0"); .WithXlaCluster("cluster_0");
auto to_int64 = [](int v) { return static_cast<int64>(v); };
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
Output size = Output size =
ops::Const(root.WithOpName("size"), {to_int64(-1), to_int64(500)}); ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(500)});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size); Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result; std::unique_ptr<Graph> result;
@ -252,13 +273,35 @@ TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithNonConstSize) {
Attr(kXlaCompileTimeConstantInputsAttr))))); Attr(kXlaCompileTimeConstantInputsAttr)))));
} }
TEST(SliceToDynamicSliceRewriteTest, ScalarSlice) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
Output size = ops::Const<int64>(root.WithOpName("size"), {});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
Node* static_shaped_slice = testing::FindNodeByName(
result.get(), "slice/static_shaped_slice/static_shaped_slice");
ASSERT_NE(static_shaped_slice, nullptr);
EXPECT_THAT(static_shaped_slice,
NodeWith(Op("Slice"), Attr(kXlaCompileTimeConstantInputsAttr),
Inputs(_, _, Out(NodeWith(Name(size.node()->name()))))));
}
TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) { TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
Scope root = Scope::NewRootScope() Scope root = Scope::NewRootScope()
.ExitOnError() .ExitOnError()
.WithAssignedDevice(kDeviceName) .WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0"); .WithXlaCluster("cluster_0");
auto to_int64 = [](int v) { return static_cast<int64>(v); }; auto ToInt64 = [](int v) { return static_cast<int64>(v); };
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT); Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64); Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT64);
@ -271,7 +314,7 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder); ops::Slice(root.WithOpName("slice"), input, begin, size_placeholder);
Output size = Output size =
ops::Const(root.WithOpName("size"), {{to_int64(-1)}, {to_int64(500)}}); ops::Const(root.WithOpName("size"), {{ToInt64(-1)}, {ToInt64(500)}});
TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2)); TF_ASSERT_OK(root.graph()->UpdateEdge(size.node(), 0, slice.node(), 2));
std::unique_ptr<Graph> result; std::unique_ptr<Graph> result;
@ -281,5 +324,82 @@ TEST(SliceToDynamicSliceRewriteTest, IndicesNotVector) {
Not(Contains(NodeWith(Op("Slice"), Not(Contains(NodeWith(Op("Slice"),
Attr(kXlaCompileTimeConstantInputsAttr))))); Attr(kXlaCompileTimeConstantInputsAttr)))));
} }
TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceInput) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
Output begin = ops::Placeholder(root.WithOpName("begin"), DT_INT32);
Output size_a = ops::Const(root.WithOpName("size_a"), {-1, 500});
Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size_a);
Output size_b = ops::Const(root.WithOpName("size_a"), {-1, 200});
Output slice_with_slice_input = ops::Slice(
root.WithOpName("slice_with_slice_input"), slice, begin, size_b);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
Node* static_shaped_slice = testing::FindNodeByName(
result.get(),
"slice_with_slice_input/static_shaped_slice/static_shaped_slice");
ASSERT_NE(static_shaped_slice, nullptr);
EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
<< "Expected DT_FLOAT, was "
<< DataType_Name(static_shaped_slice->output_type(0));
EXPECT_THAT(
static_shaped_slice,
NodeWith(
Op("Slice"),
Inputs(Out(NodeWith(
Op("Slice"),
Name("slice/static_shaped_slice/static_shaped_slice"))),
_, _)));
}
TEST(SliceToDynamicSliceRewriteTest, SliceWithSliceBegin) {
Scope root = Scope::NewRootScope()
.ExitOnError()
.WithAssignedDevice(kDeviceName)
.WithXlaCluster("cluster_0");
Output input_float =
ops::Placeholder(root.WithOpName("input_float"), DT_FLOAT);
Output input_i64 = ops::Placeholder(root.WithOpName("input_i64"), DT_INT64);
Output begin_begin =
ops::Placeholder(root.WithOpName("begin_begin"), DT_INT32);
Output begin_size = ops::Const(root.WithOpName("begin_size"), {-1});
Output begin =
ops::Slice(root.WithOpName("begin"), input_i64, begin_begin, begin_size);
Output size =
ops::Const(root.WithOpName("size"), {ToInt64(-1), ToInt64(200)});
Output slice_with_slice_begin = ops::Slice(
root.WithOpName("slice_with_slice_begin"), input_float, begin, size);
std::unique_ptr<Graph> result;
TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
Node* static_shaped_slice = testing::FindNodeByName(
result.get(),
"slice_with_slice_begin/static_shaped_slice/static_shaped_slice");
ASSERT_NE(static_shaped_slice, nullptr);
EXPECT_EQ(static_shaped_slice->output_type(0), DT_FLOAT)
<< "Expected DT_FLOAT, was "
<< DataType_Name(static_shaped_slice->output_type(0));
EXPECT_THAT(
static_shaped_slice,
NodeWith(
Op("Slice"),
Inputs(_,
Out(NodeWith(
Op("Slice"),
Name("begin/static_shaped_slice/static_shaped_slice"))),
_)));
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -12,10 +12,10 @@ cc_library(
hdrs = ["xla_ops.h"], hdrs = ["xla_ops.h"],
deps = [ deps = [
"//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:xla_compilation_cache", "//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/jit/legacy_flags:xla_ops_common_flags",
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -418,7 +418,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
cannot_compile_cluster = cannot_compile_cluster_; cannot_compile_cluster = cannot_compile_cluster_;
} }
if (legacy_flags::GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation ||
cannot_compile_cluster) { cannot_compile_cluster) {
executable = nullptr; executable = nullptr;
} else { } else {

View File

@ -1,65 +0,0 @@
# Legacy command line flags for the XLA bridge libraries.
# Please do not add more flags to this package.
# The XLA bridge libraries were written in an environment that allowed
# command-line flags to be scattered freely throughout the libraries. This
# model, while initially convenient, leads to a proliferation in unused command
# line flags in tests and binaries, and serious problems in servers, where one
# might wish parameters to be different in independent RPC calls to the same
# routine.
#
# Please don't add more flags. If you're a library author, pass options and
# parameters explicitly through the library's interface.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
cc_library(
name = "mark_for_compilation_pass_flags",
srcs = ["mark_for_compilation_pass_flags.cc"],
hdrs = ["mark_for_compilation_pass_flags.h"],
deps =
[
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "xla_device_flags",
srcs = ["xla_device_flags.cc"],
hdrs = ["xla_device_flags.h"],
deps =
[
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "build_xla_ops_pass_flags",
srcs = ["build_xla_ops_pass_flags.cc"],
hdrs = ["build_xla_ops_pass_flags.h"],
deps =
[
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
cc_library(
name = "xla_ops_common_flags",
srcs = ["xla_ops_common_flags.cc"],
hdrs = ["xla_ops_common_flags.h"],
deps =
[
"//tensorflow/compiler/xla:parse_flags_from_env",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)

View File

@ -1,47 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <mutex> // NOLINT
#include "tensorflow/compiler/jit/legacy_flags/build_xla_ops_pass_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
namespace {
BuildXlaOpsPassFlags* flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
void AllocateAndParseFlags() {
flags = new BuildXlaOpsPassFlags;
flags->tf_xla_enable_lazy_compilation = true;
flag_list = new std::vector<Flag>({
Flag("tf_xla_enable_lazy_compilation",
&flags->tf_xla_enable_lazy_compilation, ""),
});
xla::ParseFlagsFromEnv(*flag_list);
}
} // namespace
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return *flags;
}
} // namespace legacy_flags
} // namespace tensorflow

View File

@ -1,37 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_
namespace tensorflow {
namespace legacy_flags {
// Flags for the build_xla_ops pass.
struct BuildXlaOpsPassFlags {
// Enables lazy compilation for TF/XLA (only when auto-clustering) if true.
// Defaults to true.
bool tf_xla_enable_lazy_compilation;
};
// Parses the flags in BuildXlaOpsPassFlags from the TF_XLA_FLAGS environment
// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS
// only the first time this routine is called.
const BuildXlaOpsPassFlags& GetBuildXlaOpsPassFlags();
} // namespace legacy_flags
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_BUILD_XLA_OPS_PASS_FLAGS_H_

View File

@ -1,98 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
#include <mutex>
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static MarkForCompilationPassFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new MarkForCompilationPassFlags;
flags->tf_xla_auto_jit = 0;
flags->tf_xla_min_cluster_size = 2;
flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
flags->tf_xla_clustering_debug = false;
flags->tf_xla_cpu_global_jit = false;
flags->tf_xla_clustering_fuel = std::numeric_limits<int64>::max();
flags->tf_xla_fusion_only = false;
flag_list = new std::vector<Flag>(
{Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
"Control compilation of operators into XLA computations on CPU and "
"GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
"things very likely to be improved; 2 = on for everything. "
"Experimental."),
Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
"Minimum number of operators in an XLA compilation. Ignored for "
"operators placed on an XLA device or operators explicitly marked "
"for compilation."),
Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
"Maximum number of operators in an XLA compilation."),
Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit,
"Enables global JIT compilation for CPU via SessionOptions."),
Flag("tf_xla_clustering_fuel", &flags->tf_xla_clustering_fuel,
"Places an artificial limit on the number of ops marked as "
"eligible for clustering."),
Flag("tf_xla_fusion_only", &flags->tf_xla_fusion_only,
"enable fusion of element-wise operations only using XLA when "
"global_jit_level is ON*.")});
xla::ParseFlagsFromEnv(*flag_list);
if (VLOG_IS_ON(1)) {
VLOG(1) << "Parsed MarkForCompilationPassFlags:";
VLOG(1) << " tf_xla_auto_jit = " << flags->tf_xla_auto_jit;
VLOG(1) << " tf_xla_min_cluster_size = " << flags->tf_xla_min_cluster_size;
VLOG(1) << " tf_xla_max_cluster_size = " << flags->tf_xla_max_cluster_size;
VLOG(1) << " tf_xla_clustering_debug = " << flags->tf_xla_clustering_debug;
VLOG(1) << " tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit;
VLOG(1) << " tf_xla_clustering_fuel = " << flags->tf_xla_clustering_fuel;
VLOG(1) << " tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
}
}
// Append to *append_to flag definitions associated with the XLA bridge's
// mark_for_compilation_pass module.
void AppendMarkForCompilationPassFlags(std::vector<Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the MarkForCompilationPassFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace tensorflow

View File

@ -1,56 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Legacy flags for the XLA bridge's xla_device module.
#include <mutex>
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static XlaDeviceFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new XlaDeviceFlags;
flags->tf_xla_compile_on_demand = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_compile_on_demand", &flags->tf_xla_compile_on_demand,
"Switch a device into 'on-demand' mode, where instead of "
"autoclustering ops are compiled one by one just-in-time."),
});
xla::ParseFlagsFromEnv(*flag_list);
}
// Return a pointer to the XlaDeviceFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
XlaDeviceFlags* GetXlaDeviceFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace tensorflow

View File

@ -1,47 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_
// Legacy flags for the XLA bridge's xla_device module.
#include <vector>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
// The values of flags associated with the XLA bridge's
// xla_device module.
typedef struct {
// Switch the CPU device into "on-demand" mode, where instead of
// autoclustering ops are compiled one by one just-in-time.
// Enabling this mode by a legacy flag is a temporary mechanism. When this
// feature is battle-tested, we will switch this to be a session option.
bool tf_xla_compile_on_demand;
} XlaDeviceFlags;
// Return a pointer to the XlaDeviceFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
XlaDeviceFlags* GetXlaDeviceFlags();
} // namespace legacy_flags
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_DEVICE_FLAGS_H_

View File

@ -1,52 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <mutex> // NOLINT
#include <vector>
#include "tensorflow/compiler/jit/legacy_flags/xla_ops_common_flags.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace legacy_flags {
XlaOpsCommonFlags* flags;
std::vector<Flag>* flag_list;
std::once_flag flags_init;
void AllocateAndParseFlags() {
flags = new XlaOpsCommonFlags;
flags->tf_xla_always_defer_compilation = false;
flag_list = new std::vector<Flag>({
Flag("tf_xla_always_defer_compilation",
&flags->tf_xla_always_defer_compilation, ""),
});
xla::ParseFlagsFromEnv(*flag_list);
if (VLOG_IS_ON(1)) {
VLOG(1) << "Parsed XlaOpsCommonFlags:";
VLOG(1) << " tf_xla_always_defer_compilation = "
<< flags->tf_xla_always_defer_compilation;
}
}
const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
std::call_once(flags_init, &AllocateAndParseFlags);
return *flags;
}
} // namespace legacy_flags
} // namespace tensorflow

View File

@ -1,36 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_
#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_
namespace tensorflow {
namespace legacy_flags {
// Flags common to the _Xla* ops and their kernels.
struct XlaOpsCommonFlags {
// If true, _XlaCompile always refuses to compile the cluster, which means the
// XLA clusters always run in the TF executor. Defaults to false.
bool tf_xla_always_defer_compilation;
};
// Parses the flags in XlaOpsCommonFlags from the TF_XLA_FLAGS environment
// variable and returns a reference to the parsed copy. Parses TF_XLA_FLAGS
// only the first time this routine is called.
const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
} // namespace legacy_flags
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_XLA_OPS_COMMON_FLAGS_H_

View File

@ -24,8 +24,8 @@ limitations under the License.
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h"
@ -72,6 +72,11 @@ struct OperationFilter {
// to resort to a dummy implementation. Currently Assert and CheckNumerics ops // to resort to a dummy implementation. Currently Assert and CheckNumerics ops
// have dummy XLA implementations. // have dummy XLA implementations.
bool allow_dummy_ops; bool allow_dummy_ops;
// Whether ops that produce or consume DT_VARIANT values are allowed. We
// don't auto-cluster these ops because we don't yet support live-in or
// live-out DT_VARIANT values.
bool allow_ops_producing_or_consuming_variant;
}; };
bool IsDummyImplOp(absl::string_view op_name) { bool IsDummyImplOp(absl::string_view op_name) {
@ -81,7 +86,13 @@ bool IsDummyImplOp(absl::string_view op_name) {
bool IsStatefulRandomOp(absl::string_view op_name) { bool IsStatefulRandomOp(absl::string_view op_name) {
return op_name == "RandomUniform" || op_name == "RandomShuffle" || return op_name == "RandomUniform" || op_name == "RandomShuffle" ||
op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" ||
op_name == "TruncatedNormal"; op_name == "TruncatedNormal" || op_name == "Multinomial";
}
bool OpProducesOrConsumesVariant(const Node& node) {
auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
return absl::c_any_of(node.input_types(), is_variant) ||
absl::c_any_of(node.output_types(), is_variant);
} }
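For readers without the TensorFlow headers at hand, here is a self-contained restatement of that variant check. `DataType` below is a stand-in enum rather than the real TensorFlow type, and the op's type lists are passed in explicitly instead of being read from a `Node`; TensorList ops such as TensorListReserve are the motivating case, as the new tests further down exercise.

```cpp
// Stand-alone sketch of OpProducesOrConsumesVariant: an op is rejected for
// auto-clustering if any of its input or output types is DT_VARIANT.
#include <algorithm>
#include <iostream>
#include <vector>

enum class DataType { DT_FLOAT, DT_INT32, DT_VARIANT };  // stand-in enum

bool ProducesOrConsumesVariant(const std::vector<DataType>& input_types,
                               const std::vector<DataType>& output_types) {
  auto is_variant = [](DataType d) { return d == DataType::DT_VARIANT; };
  return std::any_of(input_types.begin(), input_types.end(), is_variant) ||
         std::any_of(output_types.begin(), output_types.end(), is_variant);
}

int main() {
  std::cout << ProducesOrConsumesVariant({DataType::DT_INT32},
                                         {DataType::DT_VARIANT}) << "\n";  // 1
  std::cout << ProducesOrConsumesVariant({DataType::DT_FLOAT},
                                         {DataType::DT_INT32}) << "\n";    // 0
  return 0;
}
```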
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
@ -246,6 +257,10 @@ bool IsCompilableCall(const NodeDef& call_def,
if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) { if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
return false; return false;
} }
if (!op_filter.allow_ops_producing_or_consuming_variant &&
OpProducesOrConsumesVariant(*node)) {
return false;
}
if (!HasXLAKernel(*node, jit_device_type) && if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1, !IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
lib_runtime)) { lib_runtime)) {
@ -427,8 +442,7 @@ Status FindCompilationCandidates(
BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr,
&compile_time_const_nodes)); &compile_time_const_nodes));
int64& fuel = int64& fuel = GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel;
// Iterate over nodes in sorted order so that compiler fuel is deterministic. // Iterate over nodes in sorted order so that compiler fuel is deterministic.
// We can't simply pass op_nodes().begin() and op_nodes().end to the // We can't simply pass op_nodes().begin() and op_nodes().end to the
@ -471,16 +485,15 @@ Status FindCompilationCandidates(
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)); XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration));
DeviceType jit_device_type(registration->compilation_device_name); DeviceType jit_device_type(registration->compilation_device_name);
bool always_auto_cluster = registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
OperationFilter op_filter; OperationFilter op_filter;
op_filter.allow_resource_ops = registration->compile_resource_ops; op_filter.allow_resource_ops = registration->compile_resource_ops;
op_filter.allow_stateful_rng_ops = op_filter.allow_stateful_rng_ops = always_auto_cluster;
(registration->autoclustering_policy == op_filter.allow_control_trigger = always_auto_cluster;
XlaOpRegistry::AutoclusteringPolicy::kAlways); op_filter.allow_dummy_ops = always_auto_cluster;
op_filter.allow_control_trigger = op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster;
(registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);
op_filter.allow_dummy_ops = (registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways);
if (!HasXLAKernel(*node, jit_device_type) && if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0, !IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
@ -504,6 +517,12 @@ Status FindCompilationCandidates(
<< node->type_string() << ")"; << node->type_string() << ")";
continue; continue;
} }
if (!op_filter.allow_ops_producing_or_consuming_variant &&
OpProducesOrConsumesVariant(*node)) {
VLOG(2) << "Rejecting " << node->name()
<< ": produces or consumes DT_VARIANT";
continue;
}
if (!op_filter.allow_resource_ops && if (!op_filter.allow_resource_ops &&
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
@ -607,8 +626,7 @@ OptimizerOptions::GlobalJitLevel GetGlobalJitLevel(
// To set compilation to be on by default, change the following line. // To set compilation to be on by default, change the following line.
global_jit_level = OptimizerOptions::OFF; global_jit_level = OptimizerOptions::OFF;
} }
legacy_flags::MarkForCompilationPassFlags* flags = MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
legacy_flags::GetMarkForCompilationPassFlags();
if (flags->tf_xla_auto_jit == -1 || if (flags->tf_xla_auto_jit == -1 ||
(1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) { (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
// If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
@ -641,6 +659,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
op_filter.allow_stateful_rng_ops = true; op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true; op_filter.allow_control_trigger = true;
op_filter.allow_dummy_ops = true; op_filter.allow_dummy_ops = true;
op_filter.allow_ops_producing_or_consuming_variant = true;
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr); return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
} }
@ -651,8 +670,7 @@ Status MarkForCompilationPass::Run(
// device ahead of time. // device ahead of time.
OptimizerOptions::GlobalJitLevel global_jit_level = OptimizerOptions::GlobalJitLevel global_jit_level =
GetGlobalJitLevel(options); GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags = MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
legacy_flags::GetMarkForCompilationPassFlags();
bool fusion_only = flags->tf_xla_fusion_only; bool fusion_only = flags->tf_xla_fusion_only;
VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only; VLOG(1) << "flags->tf_xla_fusion_only = " << flags->tf_xla_fusion_only;
@ -953,8 +971,7 @@ Status MarkForCompilationPass::RunImpl(
OptimizerOptions::GlobalJitLevel global_jit_level = OptimizerOptions::GlobalJitLevel global_jit_level =
GetGlobalJitLevel(options); GetGlobalJitLevel(options);
legacy_flags::MarkForCompilationPassFlags* flags = MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
legacy_flags::GetMarkForCompilationPassFlags();
// Repeatedly contract edges between clusters that are on the same device, // Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle. // provided the contraction would not create a cycle.

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
@ -1147,5 +1148,80 @@ TEST(XlaCompilationTest, DontAutoClusterDummyOps) {
EXPECT_EQ(clusters["test/check"], ""); EXPECT_EQ(clusters["test/check"], "");
} }
TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
Output tensor_list_reserve = ops::TensorListReserve(
root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
}
TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
Scope root = Scope::NewRootScope().ExitOnError();
Output dummy_input =
ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
Output variant_input =
ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);
// Create one more node so that we don't avoid creating a cluster solely
// because it would be trivial.
Output dummy_cast =
ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);
Output tensor_list_element_shape = ops::TensorListElementShape(
root.WithOpName("test/tensor_list_element_shape"), variant_input,
DT_INT32);
root.graph()->AddControlEdge(dummy_cast.node(),
tensor_list_element_shape.node());
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
}
TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
Output tensor_list_reserve = ops::TensorListReserve(
root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
for (Node* n : graph->nodes()) {
if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
n->set_assigned_device_name(xla_cpu_device);
}
}
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
EXPECT_NE(clusters["test/tensor_list_reserve"], "");
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -34,15 +34,9 @@ namespace tensorflow {
// //
// It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
// make this more direct, but probably not worth it solely for this test. // make this more direct, but probably not worth it solely for this test.
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
auto delete_devices = gtl::MakeCleanup([&] {
for (Device* d : devices) {
delete d;
}
});
GraphOptimizationPassOptions opt_options; GraphOptimizationPassOptions opt_options;
opt_options.graph = graph; opt_options.graph = graph;
opt_options.session_options = session_options; opt_options.session_options = session_options;

View File

@ -18,3 +18,9 @@ tf_gen_op_wrapper_py(
out = "xla_ops.py", out = "xla_ops.py",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"], deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
) )
py_library(
name = "xla_ops_grad",
srcs = ["xla_ops_grad.py"],
deps = ["//tensorflow/python:framework_ops"],
)

View File

@ -1,3 +1,4 @@
"""Gradients for XLA ops."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -12,21 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""dnn python module.
Importing from tensorflow.python.estimator is unsupported
and will soon break!
"""
# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow_estimator.contrib.estimator.python.estimator import dnn from tensorflow.python.framework import ops
# Include attrs that start with single underscore.
_HAS_DYNAMIC_ATTRIBUTES = True
dnn.__all__ = [s for s in dir(dnn) if not s.startswith('__')]
from tensorflow_estimator.contrib.estimator.python.estimator.dnn import * @ops.RegisterGradient("XlaClusterOutput")
def _XlaClusterOutputGrad(_, grad):
del grad # unused
raise RuntimeError("Gradient computation of graph in xla.compile() is "
"prohibited because it can cause performance degradation."
"Please move gradient computation inside xla.compile().")

View File

@ -26,6 +26,10 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace { namespace {
bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
namespace reduce_device_to_host_copies {
Status FindNodesToDecluster(const Graph& graph, Status FindNodesToDecluster(const Graph& graph,
absl::flat_hash_set<Node*>* result, absl::flat_hash_set<Node*>* result,
absl::Span<Node* const> post_order) { absl::Span<Node* const> post_order) {
@ -140,8 +144,6 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
return Status::OK(); return Status::OK();
} }
bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
// Clones nodes to outside their cluster to avoid device-to-host copies. For // Clones nodes to outside their cluster to avoid device-to-host copies. For
// instance, converts this: // instance, converts this:
// //
@ -168,7 +170,7 @@ bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
// where the ===> arrow has a hostmem source and destination and would entail a // where the ===> arrow has a hostmem source and destination and would entail a
// device to host copy if the source and destination were not in the same XLA // device to host copy if the source and destination were not in the same XLA
// cluster. // cluster.
Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { Status PartiallyDeclusterGraph(Graph* graph) {
// When deciding whether to decluster a particular node, we base our decision // When deciding whether to decluster a particular node, we base our decision
// on if we've decided that some of its consumers have to be declustered too. // on if we've decided that some of its consumers have to be declustered too.
// Iterating the graph in post-order guarantees that consumers have been // Iterating the graph in post-order guarantees that consumers have been
@ -206,7 +208,9 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
return Status::OK(); return Status::OK();
} }
} // namespace reduce_device_to_host_copies
namespace reduce_recompilation {
bool IsIntraClusterEdge(const Edge& edge) { bool IsIntraClusterEdge(const Edge& edge) {
absl::optional<absl::string_view> src_cluster_name = absl::optional<absl::string_view> src_cluster_name =
GetXlaClusterForNode(*edge.src()); GetXlaClusterForNode(*edge.src());
@ -269,7 +273,7 @@ Status MustCompileNode(const Node* n, bool* must_compile) {
// regress performance in any significant manner. We will have to revisit this // regress performance in any significant manner. We will have to revisit this
// algorithm with a more complex cost model if this assumption turns out to be // algorithm with a more complex cost model if this assumption turns out to be
// incorrect. // incorrect.
Status DeclusterNodesToReduceRecompilations(Graph* graph) { Status PartiallyDeclusterGraph(Graph* graph) {
std::vector<bool> compile_time_const_nodes(graph->num_node_ids()); std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
TF_RETURN_IF_ERROR(BackwardsConstAnalysis( TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
@ -322,7 +326,7 @@ Status DeclusterNodesToReduceRecompilations(Graph* graph) {
return Status::OK(); return Status::OK();
} }
} // namespace reduce_recompilation
} // namespace } // namespace
Status PartiallyDeclusterPass::Run( Status PartiallyDeclusterPass::Run(
@ -334,8 +338,9 @@ Status PartiallyDeclusterPass::Run(
Graph* graph = options.graph->get(); Graph* graph = options.graph->get();
TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); TF_RETURN_IF_ERROR(
TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); reduce_device_to_host_copies::PartiallyDeclusterGraph(graph));
TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(graph));
return Status::OK(); return Status::OK();
} }

View File

@ -386,7 +386,7 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(s.ToGraph(graph.get())); TF_ASSERT_OK(s.ToGraph(graph.get()));
// This is needed to register the XLA_GPU device. // This is needed to register the XLA_GPU device.
std::vector<Device*> devices; std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::AddDevices( TF_ASSERT_OK(DeviceFactory::AddDevices(
SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
@ -400,10 +400,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
TF_ASSERT_OK(PartiallyDecluster(&graph)); TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
for (Device* d : devices) {
delete d;
}
} }
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) { TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {

View File

@ -17,8 +17,8 @@ limitations under the License.
// operators using XLA via the XLA "Host" (CPU) backend. // operators using XLA via the XLA "Host" (CPU) backend.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/jit/xla_device_ops.h"
@ -31,13 +31,13 @@ namespace tensorflow {
class XlaCpuDeviceFactory : public DeviceFactory { class XlaCpuDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<std::unique_ptr<Device>>* devices) override;
}; };
Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options, Status XlaCpuDeviceFactory::CreateDevices(
const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
legacy_flags::XlaDeviceFlags* flags = legacy_flags::GetXlaDeviceFlags(); XlaDeviceFlags* flags = GetXlaDeviceFlags();
bool compile_on_demand = flags->tf_xla_compile_on_demand; bool compile_on_demand = flags->tf_xla_compile_on_demand;
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
@ -63,8 +63,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
options.device_ordinal = 0; options.device_ordinal = 0;
options.compilation_device_name = DEVICE_CPU_XLA_JIT; options.compilation_device_name = DEVICE_CPU_XLA_JIT;
options.use_multiple_streams = false; options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options); devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
devices->push_back(device.release());
return Status::OK(); return Status::OK();
} }
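The signature change above (`std::vector<Device*>` becoming `std::vector<std::unique_ptr<Device>>`) is what lets the callers touched elsewhere in this commit drop their delete loops and `gtl::MakeCleanup` helpers. A self-contained sketch of the ownership pattern follows; `Device` here is a placeholder class, not `tensorflow::Device`.

```cpp
// Factories hand back owning pointers, so callers no longer free devices
// manually.
#include <memory>
#include <string>
#include <vector>

class Device {
 public:
  explicit Device(std::string name) : name_(std::move(name)) {}
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

void CreateDevices(std::vector<std::unique_ptr<Device>>* devices) {
  devices->push_back(std::make_unique<Device>("/device:XLA_CPU:0"));
}

int main() {
  std::vector<std::unique_ptr<Device>> devices;
  CreateDevices(&devices);
  // No explicit delete: ownership ends when `devices` goes out of scope.
  return 0;
}
```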

View File

@ -218,6 +218,9 @@ XlaDevice::XlaDevice(const SessionOptions& session_options,
XlaDevice::~XlaDevice() { XlaDevice::~XlaDevice() {
VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this; VLOG(1) << "Destroying XLA device " << jit_device_name_ << " " << this;
mutex_lock lock(mu_); mutex_lock lock(mu_);
while (outstanding_asynchronous_operations_ > 0) {
outstanding_asynchronous_operations_cv_.wait(lock);
}
if (device_context_) { if (device_context_) {
device_context_->Unref(); device_context_->Unref();
} }
@ -384,6 +387,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
Status XlaDevice::Sync() { Status XlaDevice::Sync() {
VLOG(1) << "XlaDevice::Sync"; VLOG(1) << "XlaDevice::Sync";
tracing::ScopedActivity activity("XlaDevice::Sync", /*is_expensive=*/true);
std::shared_ptr<se::Stream> stream; std::shared_ptr<se::Stream> stream;
{ {
mutex_lock lock(mu_); mutex_lock lock(mu_);
@ -391,13 +395,46 @@ Status XlaDevice::Sync() {
} }
if (!stream) return Status::OK(); if (!stream) return Status::OK();
if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { Status status = stream->BlockHostUntilDone();
{
mutex_lock lock(mu_);
while (outstanding_asynchronous_operations_ > 0) {
outstanding_asynchronous_operations_cv_.wait(lock);
}
}
TF_RETURN_IF_ERROR(status);
if (!stream->ok()) {
return errors::Internal("XlaDevice::Sync() failed."); return errors::Internal("XlaDevice::Sync() failed.");
} }
VLOG(1) << "XlaDevice::Sync completed"; VLOG(1) << "XlaDevice::Sync completed";
return Status::OK(); return Status::OK();
} }
void XlaDevice::Sync(const DoneCallback& done) {
VLOG(1) << "XlaDevice::Sync (asynchronous)";
std::shared_ptr<se::Stream> stream;
{
mutex_lock lock(mu_);
stream = stream_;
}
if (!stream) {
done(Status::OK());
return;
}
stream->ThenEnqueueOnBackgroundThread(
[this, stream, done](se::StreamExecutor*) {
tracing::ScopedActivity activity("XlaDevice::Sync::Callback",
/*is_expensive=*/true);
mutex_lock lock(mu_);
while (outstanding_asynchronous_operations_ > 0) {
outstanding_asynchronous_operations_cv_.wait(lock);
}
done(stream->ok() ? Status::OK()
: errors::Internal("XlaDevice::Sync() failed."));
});
}
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs, const AllocatorAttributes alloc_attrs,
Tensor* tensor) { Tensor* tensor) {
@ -441,6 +478,49 @@ bool XlaDevice::RequiresSyncOnCompletion() const {
return sync_on_completion_; return sync_on_completion_;
} }
XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
XlaDevice* device)
: device_(device) {
mutex_lock lock(device_->mu_);
++device_->outstanding_asynchronous_operations_;
}
XlaDevice::AsynchronousOperationHandle::~AsynchronousOperationHandle() {
if (device_) {
mutex_lock lock(device_->mu_);
--device_->outstanding_asynchronous_operations_;
device_->outstanding_asynchronous_operations_cv_.notify_all();
}
}
XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
const XlaDevice::AsynchronousOperationHandle& other)
: device_(other.device_) {
mutex_lock lock(device_->mu_);
++device_->outstanding_asynchronous_operations_;
}
XlaDevice::AsynchronousOperationHandle::AsynchronousOperationHandle(
XlaDevice::AsynchronousOperationHandle&& other)
: device_(other.device_) {
other.device_ = nullptr;
}
XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle::
operator=(const XlaDevice::AsynchronousOperationHandle& other) {
device_ = other.device_;
mutex_lock lock(device_->mu_);
++device_->outstanding_asynchronous_operations_;
return *this;
}
XlaDevice::AsynchronousOperationHandle& XlaDevice::AsynchronousOperationHandle::
operator=(XlaDevice::AsynchronousOperationHandle&& other) {
device_ = other.device_;
other.device_ = nullptr;
return *this;
}
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) { const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter // Any op assigned to the device that isn't rewritten by the graph rewriter

View File

@ -135,6 +135,7 @@ class XlaDevice : public LocalDevice {
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override; AsyncOpKernel::DoneCallback done) override;
Status Sync() override; Status Sync() override;
void Sync(const DoneCallback& done) override;
Status FillContextMap(const Graph* graph, Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override DeviceContextMap* device_context_map) override
@ -164,7 +165,30 @@ class XlaDevice : public LocalDevice {
bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
// A simple RAII handle. On construction the device's
// outstanding_asynchronous_operations_ field is incremented; on destruction
// it is decremented.
class AsynchronousOperationHandle {
public:
AsynchronousOperationHandle(XlaDevice* device);
~AsynchronousOperationHandle();
AsynchronousOperationHandle(const AsynchronousOperationHandle& other);
AsynchronousOperationHandle(AsynchronousOperationHandle&& other);
AsynchronousOperationHandle& operator=(
const AsynchronousOperationHandle& other);
AsynchronousOperationHandle& operator=(AsynchronousOperationHandle&& other);
private:
XlaDevice* device_ = nullptr;
};
AsynchronousOperationHandle CreateAsynchronousOperationHandle() {
return AsynchronousOperationHandle(this);
}
private: private:
friend class AsynchronousOperationHandle;
xla::LocalClient* client() const; xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr) Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_); EXCLUSIVE_LOCKS_REQUIRED(mu_);
@ -227,6 +251,11 @@ class XlaDevice : public LocalDevice {
// True if the device requires XlaDevice::Sync to be called on completion // True if the device requires XlaDevice::Sync to be called on completion
// regardless of status. // regardless of status.
bool sync_on_completion_ GUARDED_BY(mu_) = false; bool sync_on_completion_ GUARDED_BY(mu_) = false;
// Count of outstanding asynchronous operations which must be zero on Sync()
// completion.
int64 outstanding_asynchronous_operations_ GUARDED_BY(mu_) = 0;
condition_variable outstanding_asynchronous_operations_cv_;
}; };
// Builds OpKernel registrations on 'device' for the JIT operators // Builds OpKernel registrations on 'device' for the JIT operators
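Taken together, the handle class and the counter/condition-variable pair above form a small "wait until all in-flight work has retired" primitive, which is what the new `Sync` overloads block on. A self-contained sketch of that pattern using only the standard library follows; the names are illustrative, and the real XlaDevice additionally coordinates with its stream.

```cpp
// RAII counting of in-flight operations plus a blocking wait, as described in
// the comments above (simplified: no copy/move customization).
#include <condition_variable>
#include <cstdint>
#include <mutex>

class PendingOps {
 public:
  // RAII handle: construction registers an in-flight operation, destruction
  // unregisters it and wakes any waiter.
  class Handle {
   public:
    explicit Handle(PendingOps* owner) : owner_(owner) {
      std::lock_guard<std::mutex> lock(owner_->mu_);
      ++owner_->outstanding_;
    }
    ~Handle() {
      std::lock_guard<std::mutex> lock(owner_->mu_);
      --owner_->outstanding_;
      owner_->cv_.notify_all();
    }

   private:
    PendingOps* owner_;
  };

  // Blocks until every outstanding handle has been destroyed.
  void WaitUntilIdle() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return outstanding_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int64_t outstanding_ = 0;
};

int main() {
  PendingOps ops;
  {
    PendingOps::Handle h(&ops);  // an asynchronous operation would hold this
  }                              // handle destroyed here
  ops.WaitUntilIdle();           // returns immediately: nothing outstanding
  return 0;
}
```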

View File

@ -29,12 +29,12 @@ namespace tensorflow {
class XlaGpuDeviceFactory : public DeviceFactory { class XlaGpuDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<std::unique_ptr<Device>>* devices) override;
}; };
Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options, Status XlaGpuDeviceFactory::CreateDevices(
const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
XlaOpRegistry::DeviceRegistration registration; XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT; registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy = registration.autoclustering_policy =
@ -70,7 +70,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
return status; return status;
} }
devices->push_back(device.release()); devices->push_back(std::move(device));
} }
return Status::OK(); return Status::OK();
} }

View File

@ -33,12 +33,12 @@ constexpr std::array<DataType, 9> kExecAllTypes = {
class XlaInterpreterDeviceFactory : public DeviceFactory { class XlaInterpreterDeviceFactory : public DeviceFactory {
public: public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix, Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<Device*>* devices) override; std::vector<std::unique_ptr<Device>>* devices) override;
}; };
Status XlaInterpreterDeviceFactory::CreateDevices( Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix, const SessionOptions& session_options, const string& name_prefix,
std::vector<Device*>* devices) { std::vector<std::unique_ptr<Device>>* devices) {
static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels( static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations; (void)registrations;
@ -61,8 +61,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
options.device_ordinal = 0; options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.use_multiple_streams = false; options.use_multiple_streams = false;
auto device = absl::make_unique<XlaDevice>(session_options, options); devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
devices->push_back(device.release());
return Status::OK(); return Status::OK();
} }

View File

@ -375,27 +375,6 @@ tf_xla_py_test(
], ],
) )
tf_xla_py_test(
name = "resampler_ops_test",
size = "small",
srcs = ["resampler_ops_test.py"],
disabled_backends = [
# TODO(b/74459949) Support BatchDot in CPU backend.
"cpu",
"cpu_ondemand",
],
# TODO(b/112295522): figure out how to make OSS build pass.
tags = ["no_oss"],
deps = [
":xla_test",
"//tensorflow/contrib/resampler:resampler_ops",
"//tensorflow/contrib/resampler:resampler_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test( tf_xla_py_test(
name = "dynamic_stitch_test", name = "dynamic_stitch_test",
size = "small", size = "small",
@ -474,7 +453,6 @@ tf_xla_py_test(
"//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:spectral_ops",
"//tensorflow/python/ops/signal", "//tensorflow/python/ops/signal",
], ],
) )

View File

@ -50,8 +50,8 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step) zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
self.assertAllClose([0.0, 0.0], var0.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var0))
self.assertAllClose([0.0, 0.0], var1.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run a step of AdagradDA # Run a step of AdagradDA
update.run() update.run()
@ -63,9 +63,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
# For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534
# similarly for others. # similarly for others.
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.904534, -1.603567]), var0.eval()) np.array([-0.904534, -1.603567]), self.evaluate(var0))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.094821, -0.189358]), var1.eval()) np.array([-0.094821, -0.189358]), self.evaluate(var1))
def testAdagradDAwithoutRegularizationBasic2(self): def testAdagradDAwithoutRegularizationBasic2(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -87,16 +87,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step) zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
# Run a step of AdagradDA # Run a step of AdagradDA
update.run() update.run()
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.904534, -1.603567]), var0.eval()) np.array([-0.904534, -1.603567]), self.evaluate(var0))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.094821, -0.189358]), var1.eval()) np.array([-0.094821, -0.189358]), self.evaluate(var1))
def testAdagradDAWithL1(self): def testAdagradDAWithL1(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -118,16 +118,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step) zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
# Run a step of AdagradDA # Run a step of AdagradDA
update.run() update.run()
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.895489, -1.59555]), var0.eval()) np.array([-0.895489, -1.59555]), self.evaluate(var0))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.085339, -0.17989]), var1.eval()) np.array([-0.085339, -0.17989]), self.evaluate(var1))
def testAdagradDAWithL1_L2(self): def testAdagradDAWithL1_L2(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -149,16 +149,16 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1]), global_step=global_step) zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
# Run a step of AdagradDA # Run a step of AdagradDA
update.run() update.run()
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.046907, -0.093659]), var0.eval()) np.array([-0.046907, -0.093659]), self.evaluate(var0))
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.004275, -0.009023]), var1.eval()) np.array([-0.004275, -0.009023]), self.evaluate(var1))
if __name__ == "__main__": if __name__ == "__main__":
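The constants asserted in this file follow from the AdagradDA update quoted in the in-test comment: with learning rate 3.0, initial squared-accumulator value 0.1, and no L1/L2 regularization, a single step reduces to -lr * g / sqrt(k + g**2). A quick NumPy check using only the values visible in the test reproduces the expected arrays:

```python
import numpy as np

lr, k = 3.0, 0.1                    # learning rate and initial accumulator value from the test
grads0 = np.array([0.1, 0.2])
grads1 = np.array([0.01, 0.02])

print(-lr * grads0 / np.sqrt(k + grads0**2))  # ~[-0.904534, -1.603567]
print(-lr * grads1 / np.sqrt(k + grads1**2))  # ~[-0.094821, -0.189358]
```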

View File

@ -42,17 +42,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1])) zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of adagrad # Run 3 steps of adagrad
for _ in range(3): for _ in range(3):
ada_update.run() ada_update.run()
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), np.array([-1.6026098728179932, -0.6026098728179932]),
self.evaluate(var0),
float_rtol=1e-5) float_rtol=1e-5)
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval(), np.array([2.715679168701172, 3.715679168701172]),
self.evaluate(var1),
float_rtol=1e-5) float_rtol=1e-5)
def testTensorLearningRate(self): def testTensorLearningRate(self):
@ -68,17 +70,19 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
zip([grads0, grads1], [var0, var1])) zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of adagrad # Run 3 steps of adagrad
for _ in range(3): for _ in range(3):
ada_update.run() ada_update.run()
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), np.array([-1.6026098728179932, -0.6026098728179932]),
self.evaluate(var0),
float_rtol=1e-5) float_rtol=1e-5)
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval(), np.array([2.715679168701172, 3.715679168701172]),
self.evaluate(var1),
float_rtol=1e-5) float_rtol=1e-5)
def testSharing(self): def testSharing(self):
@ -103,18 +107,20 @@ class AdagradOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values. # Fetch params to validate initial values.
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Mix the first and the second adagrad for 3 steps. # Mix the first and the second adagrad for 3 steps.
ada_update1.run() ada_update1.run()
ada_update2.run() ada_update2.run()
ada_update1.run() ada_update1.run()
# Validate updated params (the same as with only 1 Adagrad). # Validate updated params (the same as with only 1 Adagrad).
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval(), np.array([-1.6026098728179932, -0.6026098728179932]),
self.evaluate(var0),
float_rtol=1e-5) float_rtol=1e-5)
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval(), np.array([2.715679168701172, 3.715679168701172]),
self.evaluate(var1),
float_rtol=1e-5) float_rtol=1e-5)
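The recurring change across these optimizer tests is mechanical: direct `tensor.eval()` / `sess.run(tensor)` calls are replaced with `self.evaluate(tensor)`, the `tf.test.TestCase` helper that fetches a value in either graph or eager execution. A minimal self-contained sketch of the idiom (the test class and tensor here are invented for illustration):

```python
import tensorflow as tf


class EvaluateMigrationExample(tf.test.TestCase):

  def testEvaluateWorksInGraphAndEager(self):
    x = tf.constant([1.0, 2.0])
    # Previously: x.eval() or sess.run(x) inside a cached_session() block,
    # which only works with a graph-mode session.
    self.assertAllClose([1.0, 2.0], self.evaluate(x))


if __name__ == "__main__":
  tf.test.main()
```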

View File

@ -75,23 +75,24 @@ class AdamOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power, beta2_power = opt._get_beta_accumulators() beta1_power, beta2_power = opt._get_beta_accumulators()
# Run 3 steps of Adam # Run 3 steps of Adam
for t in range(1, 4): for t in range(1, 4):
self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) self.assertAllCloseAccordingToType(0.999**t,
self.evaluate(beta2_power))
update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, var1.eval()) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testTensorLearningRate(self): def testTensorLearningRate(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -117,23 +118,24 @@ class AdamOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power, beta2_power = opt._get_beta_accumulators() beta1_power, beta2_power = opt._get_beta_accumulators()
# Run 3 steps of Adam # Run 3 steps of Adam
for t in range(1, 4): for t in range(1, 4):
self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) self.assertAllCloseAccordingToType(0.999**t,
self.evaluate(beta2_power))
update.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) update.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, var1.eval()) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testSharing(self): def testSharing(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -162,13 +164,14 @@ class AdamOptimizerTest(xla_test.XLATestCase):
beta1_power, beta2_power = opt._get_beta_accumulators() beta1_power, beta2_power = opt._get_beta_accumulators()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of intertwined Adam1 and Adam2. # Run 3 steps of intertwined Adam1 and Adam2.
for t in range(1, 4): for t in range(1, 4):
self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) self.assertAllCloseAccordingToType(0.999**t,
self.evaluate(beta2_power))
if t % 2 == 0: if t % 2 == 0:
update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np}) update1.run(feed_dict={grads0: grads0_np, grads1: grads1_np})
else: else:
@ -178,8 +181,8 @@ class AdamOptimizerTest(xla_test.XLATestCase):
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, var1.eval()) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
if __name__ == "__main__": if __name__ == "__main__":
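The `adam_update_numpy` reference these assertions compare against is defined outside the excerpt; for orientation, the standard Adam step such a reference would need to implement looks like the sketch below. This is a re-derivation from the published algorithm with its usual defaults, not a copy of the helper in this file.

```python
import numpy as np


def adam_step(param, g, t, m, v, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
  """One Adam update for step t (1-based), following the published algorithm."""
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)   # bias-corrected step size
  m = beta1 * m + (1 - beta1) * g                       # first-moment estimate
  v = beta2 * v + (1 - beta2) * g * g                   # second-moment estimate
  param = param - lr_t * m / (np.sqrt(v) + epsilon)
  return param, m, v
```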

View File

@ -78,8 +78,8 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power = opt._get_beta_accumulators() beta1_power = opt._get_beta_accumulators()
@ -87,14 +87,17 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
for t in range(1, 4): for t in range(1, 4):
update.run() update.run()
self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval()) self.assertAllCloseAccordingToType(0.9**(t + 1),
self.evaluate(beta1_power))
var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2) self.assertAllCloseAccordingToType(
self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2) var0_np, self.evaluate(var0), rtol=1e-2)
self.assertAllCloseAccordingToType(
var1_np, self.evaluate(var1), rtol=1e-2)
self.assertEqual("var0_%d/AdaMax:0" % (i,), self.assertEqual("var0_%d/AdaMax:0" % (i,),
opt.get_slot(var=var0, name="m").name) opt.get_slot(var=var0, name="m").name)
@ -118,22 +121,23 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
beta1_power = opt._get_beta_accumulators() beta1_power = opt._get_beta_accumulators()
# Run 3 steps of AdaMax # Run 3 steps of AdaMax
for t in range(1, 4): for t in range(1, 4):
self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
update.run() update.run()
var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0) var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, var1.eval()) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()
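Likewise, `adamax_update_numpy` is defined outside the excerpt. The recurrence it would follow, per the AdaMax variant in the Adam paper, is sketched below as an assumption about the reference implementation rather than a copy of it; note that `v` here is an exponentially weighted infinity-norm rather than a squared-gradient accumulator.

```python
import numpy as np


def adamax_step(param, g, t, m, v, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
  """One AdaMax update for step t (1-based), per the Adam paper's AdaMax variant."""
  m = beta1 * m + (1 - beta1) * g
  v = np.maximum(beta2 * v, np.abs(g))                   # weighted infinity-norm accumulator
  param = param - (lr / (1 - beta1**t)) * m / (v + epsilon)
  return param, m, v
```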

View File

@ -90,8 +90,8 @@ class AddSignTest(xla_test.XLATestCase):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], var1.eval()) self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 7 steps of AddSign # Run 7 steps of AddSign
# first 4 steps with positive gradient # first 4 steps with positive gradient
@ -125,8 +125,8 @@ class AddSignTest(xla_test.XLATestCase):
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
var0_np, var0.eval(), half_rtol=1e-2) var0_np, self.evaluate(var0), half_rtol=1e-2)
self.assertAllCloseAccordingToType(var1_np, var1.eval()) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testDense(self): def testDense(self):
decay_steps = 10 decay_steps = 10

View File

@ -218,6 +218,21 @@ class BinaryOpsTest(xla_test.XLATestCase):
], ],
equality_test=self.ListsAreClose) equality_test=self.ListsAreClose)
# TF doesn't define these for bf16.
if dtype != dtypes.bfloat16.as_numpy_dtype:
self._testBinary(
gen_math_ops.xdivy,
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype))
self._testBinary(
gen_math_ops.xlogy,
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0],
dtype=dtype))
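The new cases exercise the "safe" division and logarithm ops: `xdivy(x, y)` returns 0 wherever `x == 0` (even when `y` is 0 or NaN) and `x / y` elsewhere, and `xlogy(x, y)` returns 0 at `x == 0` and `x * log(y)` elsewhere, which is exactly what the expected arrays above encode. A NumPy spelling of the same expectation:

```python
import numpy as np

x = np.array([0., 4., 3., 2., 1., 0.])
y = np.array([0., 5., 6., 7., 8., np.nan])

# Replace y with 1 where x == 0 so 0/0 and 0*log(nan) never get evaluated.
safe_y = np.where(x == 0, 1.0, y)
xdivy = np.where(x == 0, 0.0, x / safe_y)
xlogy = np.where(x == 0, 0.0, x * np.log(safe_y))

print(xdivy)  # ~[0, 0.8, 0.5, 0.285714, 0.125, 0]
print(xlogy)  # ~[0, 6.437752, 5.375278, 3.89182, 2.079442, 0]
```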
def testIntOps(self): def testIntOps(self):
for dtype in self.signed_int_types: for dtype in self.signed_int_types:
self._testBinary( self._testBinary(

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
@ -56,11 +57,11 @@ class CategoricalTest(xla_test.XLATestCase):
Returns: Returns:
Frequencies from sampled classes; shape [batch_size, num_classes]. Frequencies from sampled classes; shape [batch_size, num_classes].
""" """
with self.cached_session() as sess, self.test_scope(): with self.cached_session(), self.test_scope():
random_seed.set_random_seed(1618) random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples, op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32) output_dtype=dtypes.int32)
d = sess.run(op) d = self.evaluate(op)
batch_size, num_classes = logits.shape batch_size, num_classes = logits.shape
freqs_mat = [] freqs_mat = []
@ -79,15 +80,15 @@ class CategoricalTest(xla_test.XLATestCase):
def _testRngIsNotConstant(self, rng, dtype, output_dtype): def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value. # Tests that 'rng' does not always return the same value.
with self.cached_session() as sess: with self.cached_session():
with self.test_scope(): with self.test_scope():
x = rng(dtype, output_dtype) x = rng(dtype, output_dtype)
# The random-number generator, if working correctly, should produce the # The random-number generator, if working correctly, should produce the
# same output multiple times with low probability. # same output multiple times with low probability.
y = sess.run(x) y = self.evaluate(x)
z = sess.run(x) z = self.evaluate(x)
w = sess.run(x) w = self.evaluate(x)
# We use exact equality here. If the random-number generator is producing # We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical. # deterministic output, all three outputs will be bitwise identical.
@ -107,12 +108,12 @@ class CategoricalTest(xla_test.XLATestCase):
def testCategoricalIsInRange(self): def testCategoricalIsInRange(self):
for dtype in self.float_types: for dtype in self.float_types:
for output_dtype in self.output_dtypes(): for output_dtype in self.output_dtypes():
with self.cached_session() as sess: with self.cached_session():
with self.test_scope(): with self.test_scope():
x = random_ops.multinomial( x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000, array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
output_dtype=output_dtype) output_dtype=output_dtype)
y = sess.run(x) y = self.evaluate(x)
self.assertTrue((y >= 0).sum() == 1000) self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000) self.assertTrue((y < 20).sum() == 1000)
@ -138,6 +139,57 @@ class CategoricalTest(xla_test.XLATestCase):
chi2 = self._chi2(probs, freqs) chi2 = self._chi2(probs, freqs)
self.assertLess(chi2, 1e-3) self.assertLess(chi2, 1e-3)
def testStatelessMultinomialIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
with self.cached_session() as sess:
with self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless_random_ops.stateless_multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype),
1000,
seed_t,
output_dtype=output_dtype)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
self.assertTrue((y >= 0).sum() == 1000)
self.assertTrue((y < 20).sum() == 1000)
def testDeterminismMultinomial(self):
# Stateless values should be equal iff the seeds are equal (roughly)
num_samples = 10
with self.cached_session(), self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
[0.25, 0.75]]):
pure = stateless_random_ops.stateless_multinomial(
logits, num_samples, seed=seed_t)
values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds]
for s0, v0 in values:
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
def testEmpty(self):
with self.cached_session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
y = self.evaluate(x)
self.assertEqual(y.shape, (42, 0))
def testEmptyStateless(self):
with self.cached_session() as sess:
with self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
x = stateless_random_ops.stateless_multinomial(
array_ops.zeros([42, 40]),
0,
seed=seed_t,
output_dtype=dtypes.int32)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
self.assertEqual(y.shape, (42, 0))
if __name__ == '__main__': if __name__ == '__main__':
googletest.main() googletest.main()
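The added stateless tests pin down the contract that `stateless_multinomial` draws depend only on the inputs and the explicit two-element seed: the same seed always yields the same samples, and different seeds (almost surely) yield different ones. A small graph-mode usage sketch of that contract, mirroring the placeholder-feeding style of the tests and assuming the TF 1.x session API that this code base uses:

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import stateless_random_ops  # same module the test imports

logits = [[0.1, 0.25, 0.5, 0.15]]    # same toy logits as testDeterminismMultinomial
seed_t = tf.placeholder(tf.int32, shape=[2])
samples = stateless_random_ops.stateless_multinomial(logits, 10, seed=seed_t)

with tf.Session() as sess:
  a = sess.run(samples, {seed_t: [1, 2]})
  b = sess.run(samples, {seed_t: [1, 2]})
  c = sess.run(samples, {seed_t: [3, 4]})

assert np.array_equal(a, b)          # equal seeds give identical draws
print(np.array_equal(a, c))          # usually False; equal only "iff the seeds are equal (roughly)"
```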

View File

@ -43,7 +43,7 @@ class ClusteringTest(xla_test.XLATestCase):
input1 = constant_op.constant(val1, name="const1") input1 = constant_op.constant(val1, name="const1")
input2 = constant_op.constant(val2, name="const2") input2 = constant_op.constant(val2, name="const2")
output = math_ops.add(input1, input2) output = math_ops.add(input1, input2)
result = output.eval() result = self.evaluate(output)
self.assertAllClose(result, expected, rtol=1e-3) self.assertAllClose(result, expected, rtol=1e-3)
def testAddFromCpuMultiple(self): def testAddFromCpuMultiple(self):
@ -57,7 +57,7 @@ class ClusteringTest(xla_test.XLATestCase):
with self.test_scope(): with self.test_scope():
output = math_ops.add(input1, input2) output = math_ops.add(input1, input2)
for _ in xrange(10): for _ in xrange(10):
result = output.eval() result = self.evaluate(output)
self.assertAllClose(result, expected, rtol=1e-3) self.assertAllClose(result, expected, rtol=1e-3)
def testDeadlock(self): def testDeadlock(self):

View File

@ -72,7 +72,7 @@ class ConcatTest(xla_test.XLATestCase):
x2 = constant_op.constant(p2) x2 = constant_op.constant(p2)
with self.test_scope(): with self.test_scope():
c = array_ops.concat([x1, x2], 0) c = array_ops.concat([x1, x2], 0)
result = c.eval() result = self.evaluate(c)
self.assertAllEqual(result[:2, :], p1) self.assertAllEqual(result[:2, :], p1)
self.assertAllEqual(result[2:, :], p2) self.assertAllEqual(result[2:, :], p2)
@ -150,7 +150,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape) [float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 1) concated_grad = array_ops.concat(grad, 1)
result = concated_grad.eval() result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp) self.assertAllEqual(result, grad_inp)
def testGradientsSimpleAll(self): def testGradientsSimpleAll(self):
@ -177,7 +177,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape) [float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 0) concated_grad = array_ops.concat(grad, 0)
result = concated_grad.eval() result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp) self.assertAllEqual(result, grad_inp)
@ -205,7 +205,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape) [float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, 2) concated_grad = array_ops.concat(grad, 2)
result = concated_grad.eval() result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp) self.assertAllEqual(result, grad_inp)
@ -242,7 +242,7 @@ class ConcatTest(xla_test.XLATestCase):
[float(x) for x in grad_inp.flatten()], shape=output_shape) [float(x) for x in grad_inp.flatten()], shape=output_shape)
grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor]) grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
concated_grad = array_ops.concat(grad, concat_dim) concated_grad = array_ops.concat(grad, concat_dim)
result = concated_grad.eval() result = self.evaluate(concated_grad)
self.assertAllEqual(result, grad_inp) self.assertAllEqual(result, grad_inp)
@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase):
def DISABLED_testZeroSize(self): def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs # Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7) np.random.seed(7)
with self.cached_session() as sess: with self.cached_session():
with self.test_scope(): with self.test_scope():
for shape0 in (), (2,): for shape0 in (), (2,):
axis = len(shape0) axis = len(shape0)
@ -270,7 +270,7 @@ class ConcatTest(xla_test.XLATestCase):
self.assertAllEqual(c.eval(), correct) self.assertAllEqual(c.eval(), correct)
# Check gradients # Check gradients
dc = np.random.randn(*c.get_shape().as_list()) dc = np.random.randn(*c.get_shape().as_list())
dxs = sess.run(gradients_impl.gradients(c, xs, dc)) dxs = self.evaluate(gradients_impl.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis)) self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
def testConcatTuple(self): def testConcatTuple(self):
@ -280,7 +280,7 @@ class ConcatTest(xla_test.XLATestCase):
with self.test_scope(): with self.test_scope():
concat_list_t = array_ops.concat([c1, c2], 0) concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0) concat_tuple_t = array_ops.concat((c1, c2), 0)
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t))
def testConcatNoScalars(self): def testConcatNoScalars(self):
with self.cached_session(): with self.cached_session():
@ -330,47 +330,47 @@ class ConcatTest(xla_test.XLATestCase):
class ConcatOffsetTest(xla_test.XLATestCase): class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self): def testBasic(self):
with self.cached_session() as sess: with self.cached_session():
with self.test_scope(): with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32) cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32)
off = gen_array_ops.concat_offset(cdim, [s0, s1, s2]) off = gen_array_ops.concat_offset(cdim, [s0, s1, s2])
ans = sess.run(off) ans = self.evaluate(off)
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
class PackTest(xla_test.XLATestCase): class PackTest(xla_test.XLATestCase):
def testBasic(self): def testBasic(self):
with self.cached_session() as sess: with self.cached_session():
with self.test_scope(): with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32)
s2 = constant_op.constant([2, 20, 5], dtypes.int32) s2 = constant_op.constant([2, 20, 5], dtypes.int32)
packed = array_ops.stack([s0, s1, s2]) packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed) ans = self.evaluate(packed)
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self): def testScalars(self):
with self.cached_session() as sess: with self.cached_session():
with self.test_scope(): with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32) s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32)
s2 = constant_op.constant(5, dtypes.int32) s2 = constant_op.constant(5, dtypes.int32)
packed = array_ops.stack([s0, s1, s2]) packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed) ans = self.evaluate(packed)
self.assertAllEqual(ans, [2, 3, 5]) self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self): def testEmpty(self):
with self.cached_session() as sess: with self.cached_session():
with self.test_scope(): with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32) s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32) s1 = constant_op.constant([[]], dtypes.int32)
s2 = constant_op.constant([[]], dtypes.int32) s2 = constant_op.constant([[]], dtypes.int32)
packed = array_ops.stack([s0, s1, s2]) packed = array_ops.stack([s0, s1, s2])
ans = sess.run(packed) ans = self.evaluate(packed)
self.assertAllEqual(ans, [[[]], [[]], [[]]]) self.assertAllEqual(ans, [[[]], [[]], [[]]])

View File

@ -85,7 +85,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32) 1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose( output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="SAME") x, f, y_shape, strides=strides, padding="SAME")
value = output.eval() value = self.evaluate(output)
# We count the number of cells being added at the locations in the output. # We count the number of cells being added at the locations in the output.
# At the center, #cells = kernel_depth * kernel_height * kernel_width # At the center, #cells = kernel_depth * kernel_height * kernel_width
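The counting argument in this comment is easiest to see in one dimension: with an all-ones input and an all-ones kernel at stride 1, a transposed convolution simply counts how many (input cell, filter tap) pairs land on each output cell, so edge cells ramp up 1, 2, ... while fully covered interior cells receive exactly one contribution per tap. In 3-D that interior count is the product kernel_depth * kernel_height * kernel_width (27 for a 3×3×3 filter). A small NumPy analogue of the 1-D case, with arbitrarily chosen sizes:

```python
import numpy as np

# 1-D analogue: transposed convolution of an all-ones signal with an all-ones
# kernel just counts overlapping taps per output cell.
kernel, signal_len, stride = 3, 6, 1
counts = np.zeros(signal_len + kernel - 1)
for i in range(signal_len):                     # each input cell scatters the whole kernel
    counts[i * stride : i * stride + kernel] += 1.0
print(counts)  # [1. 2. 3. 3. 3. 3. 2. 1.] -- interior cells hit kernel (=3) contributions
```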
@ -135,7 +135,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32) 1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose( output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="SAME") x, f, y_shape, strides=strides, padding="SAME")
value = output.eval() value = self.evaluate(output)
for n in xrange(x_shape[0]): for n in xrange(x_shape[0]):
for k in xrange(f_shape[3]): for k in xrange(f_shape[3]):
@ -173,7 +173,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase):
1.0, shape=f_shape, name="filter", dtype=dtypes.float32) 1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose( output = nn_ops.conv3d_transpose(
x, f, y_shape, strides=strides, padding="VALID") x, f, y_shape, strides=strides, padding="VALID")
value = output.eval() value = self.evaluate(output)
cache_values = np.zeros(y_shape, dtype=np.float32) cache_values = np.zeros(y_shape, dtype=np.float32)

View File

@ -42,7 +42,7 @@ def GetRunMetadataLabels(run_metadata):
def InLabels(labels, substr): def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr.""" """Returns true iff one of the labels contains substr."""
return any([substr in x for x in labels]) return any(substr in x for x in labels)
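The one-line change here swaps a list comprehension inside `any()` for a generator expression; the result is identical, but the generator lets `any()` short-circuit on the first match instead of first materializing a full list of booleans. For example:

```python
labels = ["xla_compile", "xla_run", "cluster_0"]

# Builds an intermediate list of three booleans before any() inspects them.
any([("xla" in x) for x in labels])

# Yields lazily; any() stops at the first True (here, the first element).
any(("xla" in x) for x in labels)
```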
class DenseLayerTest(test.TestCase): class DenseLayerTest(test.TestCase):
@ -72,7 +72,7 @@ class DenseLayerTest(test.TestCase):
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
y = layers.dense(x, 3) y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables()) self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata() run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup( test_utils.RunWithWarmup(
sess, sess,
@ -97,7 +97,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope(): with jit_scope():
y = layers.dense(x, 3) y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables()) self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata() run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup( test_utils.RunWithWarmup(
sess, sess,
@ -126,7 +126,7 @@ class DenseLayerTest(test.TestCase):
with jit_scope(): with jit_scope():
y = layers.dense(x, 3) y = layers.dense(x, 3)
sess.run(variables.initialize_all_variables()) self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata() run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup( test_utils.RunWithWarmup(
sess, sess,

View File

@ -58,6 +58,15 @@ class DynamicStitchTest(xla_test.XLATestCase):
[idx1, idx2], [val1, val2], [idx1, idx2], [val1, val2],
expected=np.array([[], [], [], []], np.int32)) expected=np.array([[], [], [], []], np.int32))
def testEmptyIndex(self):
idx1 = np.array([], dtype=np.int32)
idx2 = np.array([[], []], dtype=np.int32)
val1 = np.ndarray(shape=(0, 9), dtype=np.int32)
val2 = np.ndarray(shape=(2, 0, 9), dtype=np.int32)
self._AssertDynamicStitchResultIs([idx1, idx2], [val1, val2],
expected=np.ndarray(
shape=(0, 9), dtype=np.int32))
def testSimple1D(self): def testSimple1D(self):
val1 = np.array([0, 4, 7], dtype=np.int32) val1 = np.array([0, 4, 7], dtype=np.int32)
val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32) val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)

View File

@ -101,12 +101,12 @@ class EagerTest(xla_test.XLATestCase):
self.assertAllEqual(15, product) self.assertAllEqual(15, product)
# Run some ops graphly # Run some ops graphly
with context.graph_mode(), self.cached_session() as sess: with context.graph_mode(), self.cached_session():
with self.test_scope(): with self.test_scope():
three = constant_op.constant(3) three = constant_op.constant(3)
five = constant_op.constant(5) five = constant_op.constant(5)
product = three * five product = three * five
self.assertAllEqual(15, sess.run(product)) self.assertAllEqual(15, self.evaluate(product))
def testDegenerateSlices(self): def testDegenerateSlices(self):
with self.test_scope(): with self.test_scope():

View File

@ -27,8 +27,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import signal from tensorflow.python.ops.signal import signal
from tensorflow.python.ops import spectral_ops
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
BATCH_DIMS = (3, 5) BATCH_DIMS = (3, 5)
@ -107,39 +106,39 @@ class FFTTest(xla_test.XLATestCase):
def testFFT(self): def testFFT(self):
self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
spectral_ops.fft) signal.fft)
def testFFT2D(self): def testFFT2D(self):
self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2,
spectral_ops.fft2d) signal.fft2d)
def testFFT3D(self): def testFFT3D(self):
self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), lambda x: np.fft.fftn(x, axes=(-3, -2, -1)),
spectral_ops.fft3d) signal.fft3d)
def testIFFT(self): def testIFFT(self):
self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft,
spectral_ops.ifft) signal.ifft)
def testIFFT2D(self): def testIFFT2D(self):
self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2,
spectral_ops.ifft2d) signal.ifft2d)
def testIFFT3D(self): def testIFFT3D(self):
self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x,
lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)),
spectral_ops.ifft3d) signal.ifft3d)
def testRFFT(self): def testRFFT(self):
self._VerifyFftMethod( self._VerifyFftMethod(
INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]),
lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) lambda x: signal.rfft(x, fft_length=[x.shape[-1].value]))
def testRFFT2D(self): def testRFFT2D(self):
def _tf_fn(x): def _tf_fn(x):
return spectral_ops.rfft2d( return signal.rfft2d(
x, fft_length=[x.shape[-2].value, x.shape[-1].value]) x, fft_length=[x.shape[-2].value, x.shape[-1].value])
self._VerifyFftMethod( self._VerifyFftMethod(
@ -153,16 +152,33 @@ class FFTTest(xla_test.XLATestCase):
x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]])
def _tf_fn(x): def _tf_fn(x):
return spectral_ops.rfft3d( return signal.rfft3d(
x, x,
fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value])
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
def testRFFT3DMismatchedSize(self):
def _to_expected(x):
return np.fft.rfftn(
x,
axes=(-3, -2, -1),
s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
def _tf_fn(x):
return signal.rfft3d(
x,
fft_length=[
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
])
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
def testIRFFT(self): def testIRFFT(self):
def _tf_fn(x): def _tf_fn(x):
return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) return signal.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)])
self._VerifyFftMethod( self._VerifyFftMethod(
INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]),
@ -171,7 +187,7 @@ class FFTTest(xla_test.XLATestCase):
def testIRFFT2D(self): def testIRFFT2D(self):
def _tf_fn(x): def _tf_fn(x):
return spectral_ops.irfft2d( return signal.irfft2d(
x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)])
self._VerifyFftMethod( self._VerifyFftMethod(
@ -195,7 +211,7 @@ class FFTTest(xla_test.XLATestCase):
s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)])
def _tf_fn(x): def _tf_fn(x):
return spectral_ops.irfft3d( return signal.irfft3d(
x, x,
fft_length=[ fft_length=[
x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1)
@ -203,6 +219,30 @@ class FFTTest(xla_test.XLATestCase):
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
def testIRFFT3DMismatchedSize(self):
def _to_input(x):
return np.fft.rfftn(
np.real(x),
axes=(-3, -2, -1),
s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
def _to_expected(x):
return np.fft.irfftn(
x,
axes=(-3, -2, -1),
s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
def _tf_fn(x):
return signal.irfft3d(
x,
fft_length=[
x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
])
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
if __name__ == "__main__": if __name__ == "__main__":
googletest.main() googletest.main()
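Two things change in this file: the tests now go through `tensorflow.python.ops.signal` instead of the removed `spectral_ops` import, and the new `*MismatchedSize` cases check that `fft_length` may differ from the input shape, with the input cropped or zero-padded to `fft_length`, matching the behaviour of NumPy's `s=` argument that the tests use as the reference. A NumPy-only sketch of that cropping/padding reference, with arbitrarily chosen shapes:

```python
import numpy as np

x = np.random.randn(4, 8, 8).astype(np.float32)
fft_length = [2, 8, 16]                      # crop the first axis, zero-pad the last

# np.fft.rfftn with s= crops/pads each transformed axis to the requested length,
# which is what the mismatched-size tests expect signal.rfft3d to match.
expected = np.fft.rfftn(x, axes=(-3, -2, -1), s=fft_length)
print(expected.shape)                        # (2, 8, 9): last axis is fft_length[-1] // 2 + 1
```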

View File

@ -129,7 +129,7 @@ class FIFOQueueTest(xla_test.XLATestCase):
enqueue_op.run() enqueue_op.run()
for i in xrange(len(elems)): for i in xrange(len(elems)):
vals = dequeued_t.eval() vals = self.evaluate(dequeued_t)
self.assertEqual([elems[i]], vals) self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self): def testEnqueueAndBlockingDequeue(self):
@ -192,9 +192,9 @@ class FIFOQueueTest(xla_test.XLATestCase):
self.assertEqual([], size.get_shape()) self.assertEqual([], size.get_shape())
enqueue_op.run() enqueue_op.run()
self.assertEqual(1, size.eval()) self.assertEqual(1, self.evaluate(size))
dequeued_t.op.run() dequeued_t.op.run()
self.assertEqual(0, size.eval()) self.assertEqual(0, self.evaluate(size))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -50,14 +50,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], var0.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var0))
self.assertAllClose([0.0, 0.0], var1.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run Ftrl for a few steps # Run Ftrl for a few steps
for _ in range(steps): for _ in range(steps):
ftrl_update.run() ftrl_update.run()
return var0.eval(), var1.eval() return self.evaluate(var0), self.evaluate(var1)
def equivAdagradTest_AdagradPart(self, steps, dtype): def equivAdagradTest_AdagradPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@ -65,14 +65,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) adagrad_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], var0.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var0))
self.assertAllClose([0.0, 0.0], var1.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run Adagrad for a few steps # Run Adagrad for a few steps
for _ in range(steps): for _ in range(steps):
adagrad_update.run() adagrad_update.run()
return var0.eval(), var1.eval() return self.evaluate(var0), self.evaluate(var1)
def equivGradientDescentTest_FtrlPart(self, steps, dtype): def equivGradientDescentTest_FtrlPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@ -85,14 +85,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], var0.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var0))
self.assertAllClose([0.0, 0.0], var1.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run Ftrl for a few steps # Run Ftrl for a few steps
for _ in range(steps): for _ in range(steps):
ftrl_update.run() ftrl_update.run()
return var0.eval(), var1.eval() return self.evaluate(var0), self.evaluate(var1)
def equivGradientDescentTest_GradientDescentPart(self, steps, dtype): def equivGradientDescentTest_GradientDescentPart(self, steps, dtype):
var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype) var0, var1, grads0, grads1 = self.initVariableAndGradient(dtype)
@ -100,14 +100,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) sgd_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], var0.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var0))
self.assertAllClose([0.0, 0.0], var1.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run GradientDescent for a few steps # Run GradientDescent for a few steps
for _ in range(steps): for _ in range(steps):
sgd_update.run() sgd_update.run()
return var0.eval(), var1.eval() return self.evaluate(var0), self.evaluate(var1)
def testFtrlwithoutRegularization(self): def testFtrlwithoutRegularization(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -124,8 +124,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([0.0, 0.0], var0.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var0))
self.assertAllClose([0.0, 0.0], var1.eval()) self.assertAllClose([0.0, 0.0], self.evaluate(var1))
# Run 3 steps FTRL # Run 3 steps FTRL
for _ in range(3): for _ in range(3):
@ -134,12 +134,12 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-2.60260963, -4.29698515]), np.array([-2.60260963, -4.29698515]),
var0.eval(), self.evaluate(var0),
float_rtol=1e-4, float_rtol=1e-4,
half_rtol=1e-2) half_rtol=1e-2)
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.28432083, -0.56694895]), np.array([-0.28432083, -0.56694895]),
var1.eval(), self.evaluate(var1),
float_rtol=1e-5, float_rtol=1e-5,
half_rtol=1e-2) half_rtol=1e-2)
@ -158,8 +158,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([4.0, 3.0], var1.eval()) self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 3 steps FTRL # Run 3 steps FTRL
for _ in range(3): for _ in range(3):
@ -167,10 +167,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5, np.array([-2.55607247, -3.98729396]),
self.evaluate(var0),
1e-5,
1e-5,
float_rtol=1e-4) float_rtol=1e-4)
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) np.array([-0.28232238, -0.56096673]), self.evaluate(var1), 1e-5,
1e-5)
def testFtrlWithL1(self): def testFtrlWithL1(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -187,8 +191,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
# Fetch params to validate initial values # Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval()) self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([4.0, 3.0], var1.eval()) self.assertAllClose([4.0, 3.0], self.evaluate(var1))
# Run 10 steps FTRL # Run 10 steps FTRL
for _ in range(10): for _ in range(10):
@ -197,12 +201,14 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
# Validate updated params # Validate updated params
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-7.66718769, -10.91273689]), np.array([-7.66718769, -10.91273689]),
var0.eval(), self.evaluate(var0),
rtol=1e-4, rtol=1e-4,
bfloat16_rtol=1e-1, bfloat16_rtol=1e-1,
bfloat16_atol=1e-1) bfloat16_atol=1e-1)
self.assertAllCloseAccordingToType( self.assertAllCloseAccordingToType(
np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) np.array([-0.93460727, -1.86147261]),
self.evaluate(var1),
rtol=1e-4)
def testFtrlWithL1_L2(self): def testFtrlWithL1_L2(self):
for dtype in self.float_types: for dtype in self.float_types:
@ -219,8 +225,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
       variables.global_variables_initializer().run()
       # Fetch params to validate initial values
-      self.assertAllClose([1.0, 2.0], var0.eval())
-      self.assertAllClose([4.0, 3.0], var1.eval())
+      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+      self.assertAllClose([4.0, 3.0], self.evaluate(var1))
       # Run 10 steps FTRL
       for _ in range(10):
@@ -228,9 +234,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
       # Validate updated params
       self.assertAllCloseAccordingToType(
-          np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5)
+          np.array([-0.24059935, -0.46829352]),
+          self.evaluate(var0),
+          rtol=1e-5)
       self.assertAllCloseAccordingToType(
-          np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5)
+          np.array([-0.02406147, -0.04830509]),
+          self.evaluate(var1),
+          rtol=1e-5)

   def testFtrlWithL1_L2_L2Shrinkage(self):
     """Test the new FTRL op with support for l2 shrinkage.
@@ -254,8 +264,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
       ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
       variables.global_variables_initializer().run()
       # Fetch params to validate initial values
-      self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
-      self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+      self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+      self.assertAllCloseAccordingToType([4.0, 3.0], self.evaluate(var1))
       # Run 10 steps FTRL
       for _ in range(10):
@@ -263,9 +273,13 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
       # Validate updated params
       self.assertAllCloseAccordingToType(
-          np.array([-0.22578996, -0.44345799]), var0.eval(), rtol=1e-4)
+          np.array([-0.22578996, -0.44345799]),
+          self.evaluate(var0),
+          rtol=1e-4)
       self.assertAllCloseAccordingToType(
-          np.array([-0.14378493, -0.13229476]), var1.eval(), rtol=1e-4)
+          np.array([-0.14378493, -0.13229476]),
+          self.evaluate(var1),
+          rtol=1e-4)

   def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
     """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
@@ -291,8 +305,8 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
       update1 = opt1.apply_gradients([(grads1, var1)])
       variables.global_variables_initializer().run()
-      self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
-      self.assertAllCloseAccordingToType([1.0, 2.0], var1.eval())
+      self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var0))
+      self.assertAllCloseAccordingToType([1.0, 2.0], self.evaluate(var1))
       # Run 10 steps FTRL
       for _ in range(10):
@@ -301,7 +315,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase):
       # var0 is experiencing L2 shrinkage so it should be smaller than var1
       # in magnitude.
-      self.assertTrue((var0.eval()**2 < var1.eval()**2).all())
+      self.assertTrue((var0.eval()**2 < self.evaluate(var1)**2).all())
       accum0 = list(opt0._slots["accum"].values())[0].eval()
       accum1 = list(opt1._slots["accum"].values())[0].eval()
       # L2 shrinkage should not change how we update grad accumulator.
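
The recurring change across these test files swaps graph-only `Tensor.eval()` / `sess.run()` calls for `tf.test.TestCase.evaluate()`, which works under both graph and eager execution. A minimal sketch of the pattern follows; the class and test names below are illustrative, not taken from this diff:

```
# Minimal sketch of the self.evaluate() pattern; names are illustrative.
import tensorflow as tf


class EvaluatePatternTest(tf.test.TestCase):

  def testEvaluateReturnsNumpyValues(self):
    x = tf.constant([1.0, 2.0])
    # TestCase.evaluate() runs the tensor under graph or eager execution,
    # so the assertion no longer depends on an explicit Session object.
    self.assertAllClose([1.0, 2.0], self.evaluate(x))


if __name__ == "__main__":
  tf.test.main()
```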

View File

@@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
     expected = APlus2B(aval, bval)
-    with self.cached_session() as sess:
+    with self.cached_session():
       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -50,7 +50,7 @@ class FunctionTest(xla_test.XLATestCase):
       b = constant_op.constant(bval, name="b")
       with self.test_scope():
         call_f = Foo(a, b)
-      result = sess.run(call_f)
+      result = self.evaluate(call_f)
     self.assertAllClose(result, expected, rtol=1e-3)

   def testNestedFunctions(self):
@@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
     expected = APlus2B(aval, bval)
-    with self.cached_session() as sess:
+    with self.cached_session():
       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -76,7 +76,7 @@ class FunctionTest(xla_test.XLATestCase):
       b = constant_op.constant(bval, name="b")
       with self.test_scope():
         call_g = Foo(a, b)
-      result = sess.run(call_g)
+      result = self.evaluate(call_g)
     self.assertAllClose(result, expected, rtol=1e-3)

   def testFunctionMultipleRetvals(self):
@@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase):
     bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
     expected = Func(aval, bval)
-    with self.cached_session() as sess:
+    with self.cached_session():
       @function.Defun(dtypes.float32, dtypes.float32)
       def Foo(a, b):
@@ -100,7 +100,7 @@ class FunctionTest(xla_test.XLATestCase):
       b = constant_op.constant(bval, name="b")
       with self.test_scope():
         call_f = Foo(a, b)
-      result = sess.run(call_f)
+      result = self.evaluate(call_f)
     self.assertAllClose(result, expected, rtol=1e-3)

   def testCompileTimeConstantsInDefun(self):
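
For reference, a hedged sketch of the same a-plus-2b check written against the public `tf.function` API rather than the internal `function.Defun` decorator exercised above; this is an illustration only and not part of the diff:

```
# Illustrative only: a tf.function analogue of the APlus2B check above.
import numpy as np
import tensorflow as tf


@tf.function
def a_plus_2b(a, b):
  # Same arithmetic as the test helper: a + b * 2.
  return a + b * 2.0


aval = np.array([4, 3, 2, 1], np.float32).reshape([2, 2])
bval = np.array([5, 6, 7, 8], np.float32).reshape([2, 2])
expected = aval + bval * 2.0
result = a_plus_2b(tf.constant(aval), tf.constant(bval)).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-3)
```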

View File

@@ -75,7 +75,7 @@ def RunMetadataLabels(run_metadata):

 def InLabels(labels, substr):
   """Returns true iff one of the labels contains substr."""
-  return any([substr in x for x in labels])
+  return any(substr in x for x in labels)


 def MetadataHasXlaRunOp(run_metadata):
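
The `any(...)` change above drops the intermediate list so the scan can short-circuit on the first match. A tiny illustrative example with made-up values:

```
# Illustrative values, not from the test: any() with a generator expression
# stops at the first label that contains the substring.
labels = ["edge_padding", "xla_compile", "xla_run"]
print(any("xla_run" in x for x in labels))    # True, stops at the match
print(any("not_there" in x for x in labels))  # False, scans all labels
```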

View File

@@ -33,13 +33,13 @@ class ListDiffTest(xla_test.XLATestCase):
   def _testListDiff(self, x, y, out, idx):
     for dtype in [dtypes.int32, dtypes.int64]:
       for index_dtype in [dtypes.int32, dtypes.int64]:
-        with self.cached_session() as sess:
+        with self.cached_session():
           x_tensor = ops.convert_to_tensor(x, dtype=dtype)
           y_tensor = ops.convert_to_tensor(y, dtype=dtype)
           with self.test_scope():
             out_tensor, idx_tensor = array_ops.listdiff(
                 x_tensor, y_tensor, out_idx=index_dtype)
-          tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+          tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor])
           self.assertAllEqual(out, tf_out)
           self.assertAllEqual(idx, tf_idx)
           self.assertEqual(1, out_tensor.get_shape().ndims)

View File

@@ -120,8 +120,8 @@ class LRNTest(xla_test.XLATestCase):
       with self.test_scope():
         actual = gen_nn_ops.lrn_grad(out_grads, in_image, out_image,
                                      depth_radius, bias, alpha, beta)
-      expected_val = expected.eval()
-      actual_val = actual.eval()
+      expected_val = self.evaluate(expected)
+      actual_val = self.evaluate(actual)
       self.assertAllClose(actual_val, expected_val, rtol=1e-3)

View File

@@ -88,8 +88,8 @@ class LSTMTest(test.TestCase):
                   (basename, m_prev_scalar, c_prev_scalar, pad_scalar))
       # Initialize variables and run the unrolled LSTM step.
-      sess.run(variables.global_variables_initializer())
-      return sess.run([m, c])
+      self.evaluate(variables.global_variables_initializer())
+      return self.evaluate([m, c])

   def testLSTMCell(self):
     # Run with all-0 weights, no padding.
@@ -173,8 +173,8 @@ class LSTMTest(test.TestCase):
                   (basename, m_init_scalar, c_init_scalar, pad_scalar))
       # Initialize variables and run the unrolled LSTM layer.
-      sess.run(variables.global_variables_initializer())
-      return sess.run(out_seq)
+      self.evaluate(variables.global_variables_initializer())
+      return self.evaluate(out_seq)

   def testLSTMLayer(self):
     # Run with all-0 weights, no padding.

View File

@@ -61,37 +61,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
       self.assertFalse(slot1 in variables.trainable_variables())
       # Fetch params to validate initial values
-      self.assertAllClose([1.0, 2.0], var0.eval())
-      self.assertAllClose([3.0, 4.0], var1.eval())
+      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
       # Step 1: the momentum accumulators where 0. So we should see a normal
       # update: v -= grad * learning_rate
       mom_update.run()
       # Check that the momentum accumulators have been updated.
-      self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
-      self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+      self.assertAllCloseAccordingToType(
+          np.array([0.1, 0.1]), self.evaluate(slot0))
+      self.assertAllCloseAccordingToType(
+          np.array([0.01, 0.01]), self.evaluate(slot1))
       # Check that the parameters have been updated.
       self.assertAllCloseAccordingToType(
-          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+          self.evaluate(var0))
       self.assertAllCloseAccordingToType(
-          np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+          np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+          self.evaluate(var1))
       # Step 2: the momentum accumulators contain the previous update.
       mom_update.run()
       # Check that the momentum accumulators have been updated.
       self.assertAllCloseAccordingToType(
-          np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+          np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+          self.evaluate(slot0))
       self.assertAllCloseAccordingToType(
-          np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+          np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+          self.evaluate(slot1))
       # Check that the parameters have been updated.
       self.assertAllCloseAccordingToType(
           np.array([
               1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
               2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
-          ]), var0.eval())
+          ]), self.evaluate(var0))
       self.assertAllCloseAccordingToType(
           np.array([
-              2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
-                  (0.9 * 0.01 + 0.01) * 2.0)
-          ]), var1.eval())
+              2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+              3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
+          ]), self.evaluate(var1))

   def testNesterovMomentum(self):
     for dtype in self.float_types:
@@ -115,8 +121,8 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
           var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9)
       var1_np, accum1_np = self._update_nesterov_momentum_numpy(
           var1_np, accum1_np, 0.9, 0.1, 0.9)
-      self.assertAllCloseAccordingToType(var0_np, var0.eval())
-      self.assertAllCloseAccordingToType(var1_np, var1.eval())
+      self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+      self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))

   def testTensorLearningRateAndMomentum(self):
     for dtype in self.float_types:
@@ -141,37 +147,43 @@ class MomentumOptimizerTest(xla_test.XLATestCase):
       self.assertFalse(slot1 in variables.trainable_variables())
       # Fetch params to validate initial values
-      self.assertAllClose([1.0, 2.0], var0.eval())
-      self.assertAllClose([3.0, 4.0], var1.eval())
+      self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+      self.assertAllClose([3.0, 4.0], self.evaluate(var1))
       # Step 1: the momentum accumulators where 0. So we should see a normal
       # update: v -= grad * learning_rate
       mom_update.run()
       # Check that the momentum accumulators have been updated.
-      self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
-      self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+      self.assertAllCloseAccordingToType(
+          np.array([0.1, 0.1]), self.evaluate(slot0))
+      self.assertAllCloseAccordingToType(
+          np.array([0.01, 0.01]), self.evaluate(slot1))
       # Check that the parameters have been updated.
       self.assertAllCloseAccordingToType(
-          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+          self.evaluate(var0))
       self.assertAllCloseAccordingToType(
-          np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+          np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+          self.evaluate(var1))
       # Step 2: the momentum accumulators contain the previous update.
       mom_update.run()
       # Check that the momentum accumulators have been updated.
       self.assertAllCloseAccordingToType(
-          np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+          np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+          self.evaluate(slot0))
       self.assertAllCloseAccordingToType(
-          np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+          np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+          self.evaluate(slot1))
       # Check that the parameters have been updated.
       self.assertAllCloseAccordingToType(
           np.array([
               1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
               2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
-          ]), var0.eval())
+          ]), self.evaluate(var0))
       self.assertAllCloseAccordingToType(
           np.array([
-              2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
-                  (0.9 * 0.01 + 0.01) * 2.0)
-          ]), var1.eval())
+              2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+              3.98 - ((0.9 * 0.01 + 0.01) * 2.0)
+          ]), self.evaluate(var1))

 if __name__ == "__main__":
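
The expected values asserted above encode the classic momentum update, accum = momentum * accum + grad followed by var -= lr * accum, with lr = 2.0 and momentum = 0.9 as implied by the asserted expressions. A small numpy sketch (variable names are illustrative) reproduces the two-step values:

```
# Numpy sketch of the momentum update behind the expected values above.
import numpy as np

lr, momentum = 2.0, 0.9
var0 = np.array([1.0, 2.0])
grad0 = np.array([0.1, 0.1])
accum0 = np.zeros_like(var0)

for _ in range(2):
  accum0 = momentum * accum0 + grad0  # slot ("accum") value checked above
  var0 = var0 - lr * accum0           # parameter value checked above

# After step 2: accum0 == [0.19, 0.19] and
# var0 == [1.0 - 0.2 - 0.38, 2.0 - 0.2 - 0.38], matching the assertions.
print(accum0, var0)
```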

Some files were not shown because too many files have changed in this diff.